Source file src/crypto/rsa/pss.go
     1	
     2	
     3	
     4	
     5	package rsa
     6	
     7	
     8	
     9	
    10	
    11	import (
    12		"bytes"
    13		"crypto"
    14		"errors"
    15		"hash"
    16		"io"
    17		"math/big"
    18	)
    19	
    20	func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
    21		
    22		hLen := hash.Size()
    23		sLen := len(salt)
    24		emLen := (emBits + 7) / 8
    25	
    26		
    27		
    28		
    29		
    30		
    31	
    32		if len(mHash) != hLen {
    33			return nil, errors.New("crypto/rsa: input must be hashed message")
    34		}
    35	
    36		
    37	
    38		if emLen < hLen+sLen+2 {
    39			return nil, errors.New("crypto/rsa: key size too small for PSS signature")
    40		}
    41	
    42		em := make([]byte, emLen)
    43		db := em[:emLen-sLen-hLen-2+1+sLen]
    44		h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]
    45	
    46		
    47		
    48		
    49		
    50		
    51		
    52		
    53		
    54		
    55		
    56	
    57		var prefix [8]byte
    58	
    59		hash.Write(prefix[:])
    60		hash.Write(mHash)
    61		hash.Write(salt)
    62	
    63		h = hash.Sum(h[:0])
    64		hash.Reset()
    65	
    66		
    67		
    68		
    69		
    70		
    71	
    72		db[emLen-sLen-hLen-2] = 0x01
    73		copy(db[emLen-sLen-hLen-1:], salt)
    74	
    75		
    76		
    77		
    78	
    79		mgf1XOR(db, hash, h)
    80	
    81		
    82		
    83	
    84		db[0] &= (0xFF >> uint(8*emLen-emBits))
    85	
    86		
    87		em[emLen-1] = 0xBC
    88	
    89		
    90		return em, nil
    91	}
    92	
    93	func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
    94		
    95		
    96		
    97		
    98		
    99		hLen := hash.Size()
   100		if hLen != len(mHash) {
   101			return ErrVerification
   102		}
   103	
   104		
   105		emLen := (emBits + 7) / 8
   106		if emLen < hLen+sLen+2 {
   107			return ErrVerification
   108		}
   109	
   110		
   111		
   112		if em[len(em)-1] != 0xBC {
   113			return ErrVerification
   114		}
   115	
   116		
   117		
   118		db := em[:emLen-hLen-1]
   119		h := em[emLen-hLen-1 : len(em)-1]
   120	
   121		
   122		
   123		
   124		if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
   125			return ErrVerification
   126		}
   127	
   128		
   129		
   130		
   131		mgf1XOR(db, hash, h)
   132	
   133		
   134		
   135		db[0] &= (0xFF >> uint(8*emLen-emBits))
   136	
   137		if sLen == PSSSaltLengthAuto {
   138		FindSaltLength:
   139			for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
   140				switch db[emLen-hLen-sLen-2] {
   141				case 1:
   142					break FindSaltLength
   143				case 0:
   144					continue
   145				default:
   146					return ErrVerification
   147				}
   148			}
   149			if sLen < 0 {
   150				return ErrVerification
   151			}
   152		} else {
   153			
   154			
   155			
   156			
   157			for _, e := range db[:emLen-hLen-sLen-2] {
   158				if e != 0x00 {
   159					return ErrVerification
   160				}
   161			}
   162			if db[emLen-hLen-sLen-2] != 0x01 {
   163				return ErrVerification
   164			}
   165		}
   166	
   167		
   168		salt := db[len(db)-sLen:]
   169	
   170		
   171		
   172		
   173		
   174		
   175		
   176		var prefix [8]byte
   177		hash.Write(prefix[:])
   178		hash.Write(mHash)
   179		hash.Write(salt)
   180	
   181		h0 := hash.Sum(nil)
   182	
   183		
   184		if !bytes.Equal(h0, h) {
   185			return ErrVerification
   186		}
   187		return nil
   188	}
   189	
   190	
   191	
   192	
   193	
   194	func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
   195		nBits := priv.N.BitLen()
   196		em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
   197		if err != nil {
   198			return
   199		}
   200		m := new(big.Int).SetBytes(em)
   201		c, err := decryptAndCheck(rand, priv, m)
   202		if err != nil {
   203			return
   204		}
   205		s = make([]byte, (nBits+7)/8)
   206		copyWithLeftPad(s, c.Bytes())
   207		return
   208	}
   209	
   210	const (
   211		
   212		
   213		PSSSaltLengthAuto = 0
   214		
   215		
   216		PSSSaltLengthEqualsHash = -1
   217	)
   218	
   219	
   220	type PSSOptions struct {
   221		
   222		
   223		
   224		SaltLength int
   225	
   226		
   227		
   228		
   229		Hash crypto.Hash
   230	}
   231	
   232	
   233	
   234	func (pssOpts *PSSOptions) HashFunc() crypto.Hash {
   235		return pssOpts.Hash
   236	}
   237	
   238	func (opts *PSSOptions) saltLength() int {
   239		if opts == nil {
   240			return PSSSaltLengthAuto
   241		}
   242		return opts.SaltLength
   243	}
   244	
   245	
   246	
   247	
   248	
   249	func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) ([]byte, error) {
   250		saltLength := opts.saltLength()
   251		switch saltLength {
   252		case PSSSaltLengthAuto:
   253			saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
   254		case PSSSaltLengthEqualsHash:
   255			saltLength = hash.Size()
   256		}
   257	
   258		if opts != nil && opts.Hash != 0 {
   259			hash = opts.Hash
   260		}
   261	
   262		salt := make([]byte, saltLength)
   263		if _, err := io.ReadFull(rand, salt); err != nil {
   264			return nil, err
   265		}
   266		return signPSSWithSalt(rand, priv, hash, hashed, salt)
   267	}
   268	
   269	
   270	
   271	
   272	
   273	
   274	func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
   275		return verifyPSS(pub, hash, hashed, sig, opts.saltLength())
   276	}
   277	
   278	
   279	func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
   280		nBits := pub.N.BitLen()
   281		if len(sig) != (nBits+7)/8 {
   282			return ErrVerification
   283		}
   284		s := new(big.Int).SetBytes(sig)
   285		m := encrypt(new(big.Int), pub, s)
   286		emBits := nBits - 1
   287		emLen := (emBits + 7) / 8
   288		if emLen < len(m.Bytes()) {
   289			return ErrVerification
   290		}
   291		em := make([]byte, emLen)
   292		copyWithLeftPad(em, m.Bytes())
   293		if saltLen == PSSSaltLengthEqualsHash {
   294			saltLen = hash.Size()
   295		}
   296		return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
   297	}
   298	
View as plain text