Source file src/crypto/rsa/pss.go
1
2
3
4
5 package rsa
6
7
8
9
10
11 import (
12 "bytes"
13 "crypto"
14 "errors"
15 "hash"
16 "io"
17 "math/big"
18 )
19
20 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
21
22 hLen := hash.Size()
23 sLen := len(salt)
24 emLen := (emBits + 7) / 8
25
26
27
28
29
30
31
32 if len(mHash) != hLen {
33 return nil, errors.New("crypto/rsa: input must be hashed message")
34 }
35
36
37
38 if emLen < hLen+sLen+2 {
39 return nil, errors.New("crypto/rsa: key size too small for PSS signature")
40 }
41
42 em := make([]byte, emLen)
43 db := em[:emLen-sLen-hLen-2+1+sLen]
44 h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]
45
46
47
48
49
50
51
52
53
54
55
56
57 var prefix [8]byte
58
59 hash.Write(prefix[:])
60 hash.Write(mHash)
61 hash.Write(salt)
62
63 h = hash.Sum(h[:0])
64 hash.Reset()
65
66
67
68
69
70
71
72 db[emLen-sLen-hLen-2] = 0x01
73 copy(db[emLen-sLen-hLen-1:], salt)
74
75
76
77
78
79 mgf1XOR(db, hash, h)
80
81
82
83
84 db[0] &= (0xFF >> uint(8*emLen-emBits))
85
86
87 em[emLen-1] = 0xBC
88
89
90 return em, nil
91 }
92
93 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
94
95
96
97
98
99 hLen := hash.Size()
100 if hLen != len(mHash) {
101 return ErrVerification
102 }
103
104
105 emLen := (emBits + 7) / 8
106 if emLen < hLen+sLen+2 {
107 return ErrVerification
108 }
109
110
111
112 if em[len(em)-1] != 0xBC {
113 return ErrVerification
114 }
115
116
117
118 db := em[:emLen-hLen-1]
119 h := em[emLen-hLen-1 : len(em)-1]
120
121
122
123
124 if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
125 return ErrVerification
126 }
127
128
129
130
131 mgf1XOR(db, hash, h)
132
133
134
135 db[0] &= (0xFF >> uint(8*emLen-emBits))
136
137 if sLen == PSSSaltLengthAuto {
138 FindSaltLength:
139 for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
140 switch db[emLen-hLen-sLen-2] {
141 case 1:
142 break FindSaltLength
143 case 0:
144 continue
145 default:
146 return ErrVerification
147 }
148 }
149 if sLen < 0 {
150 return ErrVerification
151 }
152 } else {
153
154
155
156
157 for _, e := range db[:emLen-hLen-sLen-2] {
158 if e != 0x00 {
159 return ErrVerification
160 }
161 }
162 if db[emLen-hLen-sLen-2] != 0x01 {
163 return ErrVerification
164 }
165 }
166
167
168 salt := db[len(db)-sLen:]
169
170
171
172
173
174
175
176 var prefix [8]byte
177 hash.Write(prefix[:])
178 hash.Write(mHash)
179 hash.Write(salt)
180
181 h0 := hash.Sum(nil)
182
183
184 if !bytes.Equal(h0, h) {
185 return ErrVerification
186 }
187 return nil
188 }
189
190
191
192
193
194 func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
195 nBits := priv.N.BitLen()
196 em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
197 if err != nil {
198 return
199 }
200 m := new(big.Int).SetBytes(em)
201 c, err := decryptAndCheck(rand, priv, m)
202 if err != nil {
203 return
204 }
205 s = make([]byte, (nBits+7)/8)
206 copyWithLeftPad(s, c.Bytes())
207 return
208 }
209
// Special values for PSSOptions.SaltLength.
const (
	// PSSSaltLengthAuto makes the salt as large as possible when signing, and
	// makes the salt length be auto-detected when verifying.
	PSSSaltLengthAuto = 0

	// PSSSaltLengthEqualsHash makes the salt exactly as long as the output of
	// the hash function used for the signature.
	PSSSaltLengthEqualsHash = -1
)

// PSSOptions holds the optional parameters for creating and verifying PSS
// signatures.
type PSSOptions struct {
	// SaltLength is the length of the salt in bytes. It is either a positive
	// length or one of the PSSSaltLength* constants.
	SaltLength int

	// Hash, when non-zero, replaces the hash function argument given to
	// SignPSS. VerifyPSS does not consult it.
	Hash crypto.Hash
}

// HashFunc returns opts.Hash, which lets *PSSOptions satisfy
// crypto.SignerOpts.
func (opts *PSSOptions) HashFunc() crypto.Hash {
	return opts.Hash
}

// saltLength reports the configured salt length; a nil *PSSOptions is
// treated as PSSSaltLengthAuto.
func (opts *PSSOptions) saltLength() int {
	if opts != nil {
		return opts.SaltLength
	}
	return PSSSaltLengthAuto
}
244
245
246
247
248
249 func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) ([]byte, error) {
250 saltLength := opts.saltLength()
251 switch saltLength {
252 case PSSSaltLengthAuto:
253 saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
254 case PSSSaltLengthEqualsHash:
255 saltLength = hash.Size()
256 }
257
258 if opts != nil && opts.Hash != 0 {
259 hash = opts.Hash
260 }
261
262 salt := make([]byte, saltLength)
263 if _, err := io.ReadFull(rand, salt); err != nil {
264 return nil, err
265 }
266 return signPSSWithSalt(rand, priv, hash, hashed, salt)
267 }
268
269
270
271
272
273
274 func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
275 return verifyPSS(pub, hash, hashed, sig, opts.saltLength())
276 }
277
278
279 func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
280 nBits := pub.N.BitLen()
281 if len(sig) != (nBits+7)/8 {
282 return ErrVerification
283 }
284 s := new(big.Int).SetBytes(sig)
285 m := encrypt(new(big.Int), pub, s)
286 emBits := nBits - 1
287 emLen := (emBits + 7) / 8
288 if emLen < len(m.Bytes()) {
289 return ErrVerification
290 }
291 em := make([]byte, emLen)
292 copyWithLeftPad(em, m.Bytes())
293 if saltLen == PSSSaltLengthEqualsHash {
294 saltLen = hash.Size()
295 }
296 return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
297 }
298
View as plain text