...

Source file src/crypto/rsa/pss.go

     1	// Copyright 2013 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package rsa
     6	
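// This file implements the PSS signature scheme [1].
//
// [1] https://www.emc.com/collateral/white-papers/h11300-pkcs-1v2-2-rsa-cryptography-standard-wp.pdf
//     (PKCS #1 v2.2; the same specification is published as RFC 8017,
//     https://www.rfc-editor.org/rfc/rfc8017, with the same section numbering)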
    10	
    11	import (
    12		"bytes"
    13		"crypto"
    14		"errors"
    15		"hash"
    16		"io"
    17		"math/big"
    18	)
    19	
    20	func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
    21		// See [1], section 9.1.1
    22		hLen := hash.Size()
    23		sLen := len(salt)
    24		emLen := (emBits + 7) / 8
    25	
    26		// 1.  If the length of M is greater than the input limitation for the
    27		//     hash function (2^61 - 1 octets for SHA-1), output "message too
    28		//     long" and stop.
    29		//
    30		// 2.  Let mHash = Hash(M), an octet string of length hLen.
    31	
    32		if len(mHash) != hLen {
    33			return nil, errors.New("crypto/rsa: input must be hashed message")
    34		}
    35	
    36		// 3.  If emLen < hLen + sLen + 2, output "encoding error" and stop.
    37	
    38		if emLen < hLen+sLen+2 {
    39			return nil, errors.New("crypto/rsa: key size too small for PSS signature")
    40		}
    41	
    42		em := make([]byte, emLen)
    43		db := em[:emLen-sLen-hLen-2+1+sLen]
    44		h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]
    45	
    46		// 4.  Generate a random octet string salt of length sLen; if sLen = 0,
    47		//     then salt is the empty string.
    48		//
    49		// 5.  Let
    50		//       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
    51		//
    52		//     M' is an octet string of length 8 + hLen + sLen with eight
    53		//     initial zero octets.
    54		//
    55		// 6.  Let H = Hash(M'), an octet string of length hLen.
    56	
    57		var prefix [8]byte
    58	
    59		hash.Write(prefix[:])
    60		hash.Write(mHash)
    61		hash.Write(salt)
    62	
    63		h = hash.Sum(h[:0])
    64		hash.Reset()
    65	
    66		// 7.  Generate an octet string PS consisting of emLen - sLen - hLen - 2
    67		//     zero octets. The length of PS may be 0.
    68		//
    69		// 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
    70		//     emLen - hLen - 1.
    71	
    72		db[emLen-sLen-hLen-2] = 0x01
    73		copy(db[emLen-sLen-hLen-1:], salt)
    74	
    75		// 9.  Let dbMask = MGF(H, emLen - hLen - 1).
    76		//
    77		// 10. Let maskedDB = DB \xor dbMask.
    78	
    79		mgf1XOR(db, hash, h)
    80	
    81		// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
    82		//     maskedDB to zero.
    83	
    84		db[0] &= (0xFF >> uint(8*emLen-emBits))
    85	
    86		// 12. Let EM = maskedDB || H || 0xbc.
    87		em[emLen-1] = 0xBC
    88	
    89		// 13. Output EM.
    90		return em, nil
    91	}
    92	
    93	func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
    94		// 1.  If the length of M is greater than the input limitation for the
    95		//     hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
    96		//     and stop.
    97		//
    98		// 2.  Let mHash = Hash(M), an octet string of length hLen.
    99		hLen := hash.Size()
   100		if hLen != len(mHash) {
   101			return ErrVerification
   102		}
   103	
   104		// 3.  If emLen < hLen + sLen + 2, output "inconsistent" and stop.
   105		emLen := (emBits + 7) / 8
   106		if emLen < hLen+sLen+2 {
   107			return ErrVerification
   108		}
   109	
   110		// 4.  If the rightmost octet of EM does not have hexadecimal value
   111		//     0xbc, output "inconsistent" and stop.
   112		if em[len(em)-1] != 0xBC {
   113			return ErrVerification
   114		}
   115	
   116		// 5.  Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
   117		//     let H be the next hLen octets.
   118		db := em[:emLen-hLen-1]
   119		h := em[emLen-hLen-1 : len(em)-1]
   120	
   121		// 6.  If the leftmost 8 * emLen - emBits bits of the leftmost octet in
   122		//     maskedDB are not all equal to zero, output "inconsistent" and
   123		//     stop.
   124		if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
   125			return ErrVerification
   126		}
   127	
   128		// 7.  Let dbMask = MGF(H, emLen - hLen - 1).
   129		//
   130		// 8.  Let DB = maskedDB \xor dbMask.
   131		mgf1XOR(db, hash, h)
   132	
   133		// 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
   134		//     to zero.
   135		db[0] &= (0xFF >> uint(8*emLen-emBits))
   136	
   137		if sLen == PSSSaltLengthAuto {
   138		FindSaltLength:
   139			for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
   140				switch db[emLen-hLen-sLen-2] {
   141				case 1:
   142					break FindSaltLength
   143				case 0:
   144					continue
   145				default:
   146					return ErrVerification
   147				}
   148			}
   149			if sLen < 0 {
   150				return ErrVerification
   151			}
   152		} else {
   153			// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
   154			//     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
   155			//     position is "position 1") does not have hexadecimal value 0x01,
   156			//     output "inconsistent" and stop.
   157			for _, e := range db[:emLen-hLen-sLen-2] {
   158				if e != 0x00 {
   159					return ErrVerification
   160				}
   161			}
   162			if db[emLen-hLen-sLen-2] != 0x01 {
   163				return ErrVerification
   164			}
   165		}
   166	
   167		// 11.  Let salt be the last sLen octets of DB.
   168		salt := db[len(db)-sLen:]
   169	
   170		// 12.  Let
   171		//          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
   172		//     M' is an octet string of length 8 + hLen + sLen with eight
   173		//     initial zero octets.
   174		//
   175		// 13. Let H' = Hash(M'), an octet string of length hLen.
   176		var prefix [8]byte
   177		hash.Write(prefix[:])
   178		hash.Write(mHash)
   179		hash.Write(salt)
   180	
   181		h0 := hash.Sum(nil)
   182	
   183		// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
   184		if !bytes.Equal(h0, h) {
   185			return ErrVerification
   186		}
   187		return nil
   188	}
   189	
   190	// signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt.
   191	// Note that hashed must be the result of hashing the input message using the
   192	// given hash function. salt is a random sequence of bytes whose length will be
   193	// later used to verify the signature.
   194	func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
   195		nBits := priv.N.BitLen()
   196		em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
   197		if err != nil {
   198			return
   199		}
   200		m := new(big.Int).SetBytes(em)
   201		c, err := decryptAndCheck(rand, priv, m)
   202		if err != nil {
   203			return
   204		}
   205		s = make([]byte, (nBits+7)/8)
   206		copyWithLeftPad(s, c.Bytes())
   207		return
   208	}
   209	
// Special values for PSSOptions.SaltLength. Any non-negative value other
// than PSSSaltLengthAuto is taken as a salt length in bytes.
const (
	// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
	// as possible when signing, and to be auto-detected when verifying.
	//
	// Because this value is 0, a SaltLength of zero always means "auto".
	// An empty salt cannot be requested through PSSOptions.
	PSSSaltLengthAuto = 0
	// PSSSaltLengthEqualsHash causes the salt length to equal the length
	// of the hash used in the signature.
	PSSSaltLengthEqualsHash = -1
)
   218	
// PSSOptions contains options for creating and verifying PSS signatures.
//
// SignPSS and VerifyPSS also accept a nil *PSSOptions, which selects
// PSSSaltLengthAuto and no hash override, the same as the zero value.
type PSSOptions struct {
	// SaltLength controls the length of the salt used in the PSS
	// signature. It can either be a number of bytes, or one of the special
	// PSSSaltLength constants.
	SaltLength int

	// Hash, if not zero, overrides the hash function passed to SignPSS.
	// This is the only way to specify the hash function when using the
	// crypto.Signer interface. VerifyPSS does not consult this field.
	Hash crypto.Hash
}
   231	
   232	// HashFunc returns pssOpts.Hash so that PSSOptions implements
   233	// crypto.SignerOpts.
   234	func (pssOpts *PSSOptions) HashFunc() crypto.Hash {
   235		return pssOpts.Hash
   236	}
   237	
   238	func (opts *PSSOptions) saltLength() int {
   239		if opts == nil {
   240			return PSSSaltLengthAuto
   241		}
   242		return opts.SaltLength
   243	}
   244	
   245	// SignPSS calculates the signature of hashed using RSASSA-PSS [1].
   246	// Note that hashed must be the result of hashing the input message using the
   247	// given hash function. The opts argument may be nil, in which case sensible
   248	// defaults are used.
   249	func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) ([]byte, error) {
   250		saltLength := opts.saltLength()
   251		switch saltLength {
   252		case PSSSaltLengthAuto:
   253			saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
   254		case PSSSaltLengthEqualsHash:
   255			saltLength = hash.Size()
   256		}
   257	
   258		if opts != nil && opts.Hash != 0 {
   259			hash = opts.Hash
   260		}
   261	
   262		salt := make([]byte, saltLength)
   263		if _, err := io.ReadFull(rand, salt); err != nil {
   264			return nil, err
   265		}
   266		return signPSSWithSalt(rand, priv, hash, hashed, salt)
   267	}
   268	
   269	// VerifyPSS verifies a PSS signature.
   270	// hashed is the result of hashing the input message using the given hash
   271	// function and sig is the signature. A valid signature is indicated by
   272	// returning a nil error. The opts argument may be nil, in which case sensible
   273	// defaults are used.
   274	func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
   275		return verifyPSS(pub, hash, hashed, sig, opts.saltLength())
   276	}
   277	
   278	// verifyPSS verifies a PSS signature with the given salt length.
   279	func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
   280		nBits := pub.N.BitLen()
   281		if len(sig) != (nBits+7)/8 {
   282			return ErrVerification
   283		}
   284		s := new(big.Int).SetBytes(sig)
   285		m := encrypt(new(big.Int), pub, s)
   286		emBits := nBits - 1
   287		emLen := (emBits + 7) / 8
   288		if emLen < len(m.Bytes()) {
   289			return ErrVerification
   290		}
   291		em := make([]byte, emLen)
   292		copyWithLeftPad(em, m.Bytes())
   293		if saltLen == PSSSaltLengthEqualsHash {
   294			saltLen = hash.Size()
   295		}
   296		return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
   297	}
   298	

View as plain text