Source file src/pkg/crypto/tls/conn.go
1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "crypto/cipher"
12 "crypto/subtle"
13 "crypto/x509"
14 "errors"
15 "fmt"
16 "io"
17 "net"
18 "sync"
19 "sync/atomic"
20 "time"
21 )
22
23
24
// A Conn represents a secured connection.
// It implements the net.Conn interface.
type Conn struct {
	// conn is the underlying transport and isClient records which side of
	// the handshake this endpoint plays.
	conn     net.Conn
	isClient bool

	// handshakeStatus is 1 once the handshake has completed (see
	// handshakeComplete). It is accessed atomically so it can be checked
	// without taking handshakeMutex; handleRenegotiation resets it to 0.
	handshakeStatus uint32
	// handshakeMutex serializes handshakes; Handshake, ConnectionState,
	// OCSPResponse and VerifyHostname read the handshake results under it.
	handshakeMutex sync.Mutex
	handshakeErr   error   // result of the last handshake attempt
	vers           uint16  // protocol version
	haveVers       bool    // whether vers has been set
	config         *Config // configuration for this connection

	// handshakes counts successfully completed handshakes, including
	// renegotiations (see Handshake and handleRenegotiation).
	handshakes       int
	didResume        bool // reported as ConnectionState.DidResume
	cipherSuite      uint16
	ocspResponse     []byte   // reported by OCSPResponse and ConnectionState
	scts             [][]byte // reported as SignedCertificateTimestamps
	peerCertificates []*x509.Certificate
	// verifiedChains contains the certificate chains that were verified, as
	// opposed to peerCertificates; VerifyHostname requires it to be
	// non-empty.
	verifiedChains [][]*x509.Certificate
	// serverName is reported by ConnectionState even before the handshake
	// completes.
	serverName string

	// secureRenegotiation is not read in this file. NOTE(review): presumably
	// set by the handshake code when secure renegotiation was negotiated.
	secureRenegotiation bool
	// ekm is the keying-material exporter exposed through ConnectionState
	// (only when renegotiation is disabled).
	ekm func(label string, context []byte, length int) ([]byte, error)
	// resumptionSecret is not used in this file. NOTE(review): presumably
	// the TLS 1.3 resumption secret set by the handshake code; confirm.
	resumptionSecret []byte

	// clientFinishedIsFirst selects which Finished value ConnectionState
	// reports as TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr is the result of the close_notify alert send attempt.
	closeNotifyErr error
	// closeNotifySent is set once closeNotify has attempted to send a
	// close_notify alert; Write then refuses with errShutdown.
	closeNotifySent bool

	// clientFinished and serverFinished hold the Finished verify data that
	// ConnectionState reports as TLSUnique.
	clientFinished [12]byte
	serverFinished [12]byte

	clientProtocol         string // reported as NegotiatedProtocol
	clientProtocolFallback bool   // true if clientProtocol was not mutually negotiated

	// Record-layer input/output state.
	in, out   halfConn     // per-direction state, each with its own lock
	rawInput  bytes.Buffer // raw bytes read from conn, not yet processed
	input     bytes.Reader // application data of the current record, aliasing rawInput
	hand      bytes.Buffer // handshake data waiting to be read
	outBuf    []byte       // scratch buffer reused by writeRecordLocked
	buffering bool         // whether write should buffer into sendBuf
	sendBuf   []byte       // data buffered by write until flush

	// bytesSent counts bytes written to conn; packetsSent is incremented by
	// maxPayloadSizeForWrite. Together they drive dynamic record sizing.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records that did not advance the
	// connection; exceeding maxUselessRecords is a fatal error.
	retryCount int

	// activeCall is accessed atomically: its low bit is set by Close, and
	// each in-progress Write adds 2.
	activeCall int32

	tmp [16]byte // scratch space; sendAlertLocked builds alerts here
}
112
113
114
115
116
117
// LocalAddr returns the local network address of the underlying connection.
func (c *Conn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}
121
122
// RemoteAddr returns the remote network address of the underlying connection.
func (c *Conn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}
126
127
128
129
// SetDeadline sets the read and write deadlines of the underlying
// connection. A zero value for t means Read and Write will not time out.
// Because Write records any write error as sticky, after a Write has timed
// out all future writes return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}
133
134
135
// SetReadDeadline sets the read deadline of the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}
139
140
141
142
// SetWriteDeadline sets the write deadline of the underlying connection.
// A zero value for t means Write will not time out. Because Write records
// any write error as sticky, after a Write has timed out all future writes
// return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}
146
147
148
// A halfConn represents one direction of the record layer connection,
// either sending or receiving. Its embedded Mutex guards the fields below.
type halfConn struct {
	sync.Mutex

	err            error       // sticky error; checked before reading/writing records
	version        uint16      // protocol version
	cipher         interface{} // cipher.Stream, aead or cbcMode; nil while unprotected
	mac            macFunction // record MAC for stream/CBC ciphers; nil when unused
	seq            [8]byte     // 64-bit big-endian record sequence number
	additionalData [13]byte    // scratch for pre-TLS 1.3 AEAD additional data: seq || type || version || length

	nextCipher interface{} // cipher to install at the next changeCipherSpec
	nextMac    macFunction // MAC to install at the next changeCipherSpec

	trafficSecret []byte // current TLS 1.3 traffic secret
}
164
// setErrorLocked records err as hc's sticky error and returns it, so callers
// can write `return hc.setErrorLocked(err)`. Passing nil clears the error.
// It does not lock hc itself.
func (hc *halfConn) setErrorLocked(err error) error {
	hc.err = err
	return err
}
169
170
171
172 func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
173 hc.version = version
174 hc.nextCipher = cipher
175 hc.nextMac = mac
176 }
177
178
179
180 func (hc *halfConn) changeCipherSpec() error {
181 if hc.nextCipher == nil || hc.version == VersionTLS13 {
182 return alertInternalError
183 }
184 hc.cipher = hc.nextCipher
185 hc.mac = hc.nextMac
186 hc.nextCipher = nil
187 hc.nextMac = nil
188 for i := range hc.seq {
189 hc.seq[i] = 0
190 }
191 return nil
192 }
193
194 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
195 hc.trafficSecret = secret
196 key, iv := suite.trafficKey(secret)
197 hc.cipher = suite.aead(key, iv)
198 for i := range hc.seq {
199 hc.seq[i] = 0
200 }
201 }
202
203
204 func (hc *halfConn) incSeq() {
205 for i := 7; i >= 0; i-- {
206 hc.seq[i]++
207 if hc.seq[i] != 0 {
208 return
209 }
210 }
211
212
213
214
215 panic("TLS: sequence number wraparound")
216 }
217
218
219
220
// explicitNonceLen returns the number of bytes of explicit nonce or IV
// carried at the start of each record payload.
//
// It is zero for unprotected records and stream ciphers. AEADs report their
// own value. CBC ciphers carry a full-block IV from TLS 1.1 onwards. Any
// other cipher type panics.
func (hc *halfConn) explicitNonceLen() int {
	if hc.cipher == nil {
		return 0
	}

	switch c := hc.cipher.(type) {
	case cipher.Stream:
		return 0
	case aead:
		return c.explicitNonceLen()
	case cbcMode:
		// TLS 1.1 introduced an explicit, per-record IV for CBC.
		if hc.version < VersionTLS11 {
			return 0
		}
		return c.BlockSize()
	}
	panic("unknown cipher type")
}
241
242
243
244
// extractPadding returns, in constant time, the length of the padding to
// remove from the end of payload. It also returns a byte that is 255 if the
// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
//
// The last byte of payload is the padding length p. Valid padding is p+1
// bytes, each equal to p. When the padding is invalid, the padding length is
// masked to zero, so toRemove is 1. The caller must reject the record based
// on good.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	// t underflows (wraps to a huge value) iff paddingLen > len(payload)-1.
	t := uint(len(payload)-1) - uint(paddingLen)
	// good is 0xff when t did not underflow and 0 otherwise, computed from
	// the sign bit without branching.
	good = byte(int32(^t) >> 31)

	// At most 255 padding bytes plus the length byte itself.
	toCheck := 256
	// len(payload) is public, so branching on it leaks nothing secret.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff iff i <= paddingLen, i.e. this byte is padding.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear the bits of good where a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all bits of good together into the top bit, then replicate it
	// across the byte with an arithmetic shift: good becomes 0xff or 0.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Zero the padding length on error, so the bytes that would have been
	// stripped as padding are instead fed to the caller's MAC check. Bad
	// padding and bad MAC then fail the same way, which avoids a
	// padding-oracle distinction.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
291
292
293
294
// extractPaddingSSL30 is the SSL 3.0 counterpart of extractPadding.
//
// It reads the padding length from the last byte of payload and returns the
// number of trailing bytes to strip, plus 255 if that fits within payload.
// It returns (0, 0) if payload is empty or the length is too large. Only the
// length is validated: SSL 3.0 leaves the padding contents unspecified. The
// function is not constant time.
func extractPaddingSSL30(payload []byte) (toRemove int, good byte) {
	if len(payload) == 0 {
		return 0, 0
	}
	toRemove = int(payload[len(payload)-1]) + 1
	if toRemove > len(payload) {
		return 0, 0
	}
	return toRemove, 255
}
307
// roundUp returns a rounded up to the next multiple of b (a itself if it
// is already a multiple).
func roundUp(a, b int) int {
	rem := a % b
	pad := (b - rem) % b
	return a + pad
}
311
312
// cbcMode is an interface for block ciphers using cipher block chaining.
// Beyond cipher.BlockMode it allows the IV to be replaced, which the record
// layer does for every record that carries an explicit IV.
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
317
318
319
// decrypt authenticates and decrypts the record if protection is active at
// this stage. The returned plaintext might overlap with the input, since all
// ciphers work in place.
//
// record must hold the record header followed by the payload. On success it
// returns the plaintext and the record type (for TLS 1.3, the inner type),
// and advances the sequence number. Failures are returned as alert values.
// For MAC-based ciphers the header's length bytes (record[3:5]) are
// overwritten with the plaintext length.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, change_cipher_spec records are passed through without
	// being decrypted (RFC 8446, Appendix D.4).
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// paddingGood is 255 when the CBC padding is valid and 0 otherwise; it
	// stays 255 for ciphers that have no padding.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			nonce := payload[:explicitNonceLen]
			if len(nonce) == 0 {
				// No explicit nonce on the wire: use the sequence number.
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			additionalData := hc.additionalData[:]
			if hc.version == VersionTLS13 {
				// TLS 1.3 authenticates the record header as-is.
				additionalData = record[:recordHeaderLen]
			} else {
				// seq || type || version || plaintext length.
				copy(additionalData, hc.seq[:])
				copy(additionalData[8:], record[:3])
				n := len(payload) - c.Overhead()
				additionalData[11] = byte(n >> 8)
				additionalData[12] = byte(n)
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// The padding is validated without branching. The bytes past the
			// MAC (padding) are later passed as the extra argument to
			// hc.mac.MAC, presumably so that MAC timing does not depend on
			// the secret padding length (a Lucky13 mitigation). NOTE(review):
			// confirm against the macFunction implementation.
			if hc.version == VersionSSL30 {
				paddingLen, paddingGood = extractPaddingSSL30(payload)
			} else {
				paddingLen, paddingGood = extractPadding(payload)
			}
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Strip zero padding: the real content type is the last
			// non-zero byte of the inner plaintext. All-zero is an error.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		// n is the plaintext length. If bad padding made it negative, clamp
		// it to 0 without branching.
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers a header carrying the plaintext length, so rewrite
		// the header's length field in place.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Check the MAC and the padding together, in constant time, so that
		// a padding failure cannot be distinguished from a MAC failure.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
448
449
450
451
// sliceForAppend extends in by n bytes. It returns the extended slice (head)
// and the n newly added bytes (tail), which alias the end of head.
//
// When in has enough spare capacity, head shares in's backing array.
// Otherwise a new array is allocated and in's contents are copied into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
462
463
464
// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
// appends it to record, which must already contain the record header.
//
// Without a cipher the payload is appended as-is. Otherwise the header's
// length field is rewritten to the final protected length and the sequence
// number is incremented. rand is read only when a random explicit nonce/IV is
// needed.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// For AEADs with a short explicit nonce, a random value would be
			// too likely to collide, so use the sequence number, which is
			// unique per key. CBC IVs must instead be unpredictable (RFC
			// 5246, Appendix F.3), so they are always random.
			copy(explicitNonce, hc.seq[:])
		} else {
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var mac []byte
	if hc.mac != nil {
		mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil)
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Encrypt the real content type as the last inner byte and
			// label the outer record as application data.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so it must carry the final
			// ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// seq || header; the header's length field still holds the
			// plaintext length set by the caller.
			copy(hc.additionalData[:], hc.seq[:])
			copy(hc.additionalData[8:], record)
			record = c.Seal(record, nonce, payload, hc.additionalData[:])
		}
	case cbcMode:
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// At least one byte of padding, up to a full block.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update length to include nonce, MAC and any block padding needed.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
552
553
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
	// Msg contains a human readable string that describes the error.
	Msg string
	// RecordHeader contains the five bytes of TLS record header that
	// triggered the error.
	RecordHeader [5]byte
	// Conn provides the underlying net.Conn when the first record received
	// did not look like a TLS handshake (no alert is sent in that case).
	// It is nil for the other record header errors.
	Conn net.Conn
}
566
// Error implements the error interface, prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
568
// newRecordHeaderError builds a RecordHeaderError carrying msg and conn. It
// copies the first bytes currently buffered in c.rawInput (the offending
// record header) into RecordHeader.
func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
	err = RecordHeaderError{Msg: msg, Conn: conn}
	copy(err.RecordHeader[:], c.rawInput.Bytes())
	return err
}
575
// readRecord reads the next record, treating a ChangeCipherSpec record as
// unexpected. See readRecordOrCCS.
func (c *Conn) readRecord() error {
	return c.readRecordOrCCS(false)
}
579
// readChangeCipherSpec reads the next record, expecting a ChangeCipherSpec
// (any handshake or application data record is then an error). See
// readRecordOrCCS.
func (c *Conn) readChangeCipherSpec() error {
	return c.readRecordOrCCS(true)
}
583
584
585
586
587
588
589
590
591
592
593
594
595
// readRecordOrCCS reads one or more TLS records from the connection and
// updates the record layer state.
//
// It requires that c.input is empty. Callers in this file (Read, Handshake)
// hold c.in's lock. Records that do not advance the connection (warning
// alerts, TLS 1.3 ChangeCipherSpec, empty application data) are skipped via
// retryReadRecord. Otherwise, exactly one of the following happens:
//   - handshake data is appended to c.hand,
//   - application data is exposed through c.input,
//   - the pending read cipher is activated (ChangeCipherSpec), or
//   - an error is returned.
//
// All errors except temporary network errors are recorded on c.in.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.handshakeComplete()

	// c.input aliases c.rawInput's memory, so rawInput must not be touched
	// while application data is still pending.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	// Read the header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An EOF exactly at a record boundary is reported as a clean io.EOF,
		// even without close_notify; a partial header stays
		// io.ErrUnexpectedEOF.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are not sticky, so the read may be retried.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No valid TLS record has type 0x80, but an SSLv2 ClientHello starts
	// with a 2-byte length whose high bit is set, so this indicates an SSLv2
	// client.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	n := int(hdr[3])<<8 | int(hdr[4])
	// The TLS 1.3 record version is frozen at a legacy value, so it is not
	// compared against c.vers.
	if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: be suspicious, since the peer may not speak TLS at
		// all. Bail out before reading a body if the type is not alert or
		// handshake, or the version (>= 0x1000) is implausible.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the body.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process the record. Decryption happens in place inside rawInput.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data is never accepted unprotected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// This record advances the connection: reset the retry count.
		c.retryCount = 0
	}

	// In TLS 1.3, handshake messages must not be interleaved with other
	// record types.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// In TLS 1.3 every alert other than close_notify is fatal.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Ignore the warning and read the next record.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// A handshake message must not be split across a ChangeCipherSpec.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, ChangeCipherSpec records carry no meaning and are
		// skipped (RFC 8446, Appendix D.4).
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Skip empty records; retryReadRecord bounds how many are tolerated.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data aliases c.rawInput (no copy). This is safe because rawInput
		// is not read again until c.input has been drained (checked above).
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
756
757
758
// retryReadRecord discards the record just read and reads another one. It
// gives up with a fatal error after more than maxUselessRecords consecutive
// records that did not advance the connection.
func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
	c.retryCount++
	if c.retryCount <= maxUselessRecords {
		return c.readRecordOrCCS(expectChangeCipherSpec)
	}
	c.sendAlert(alertUnexpectedMessage)
	return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
}
767
768
769
770
771 type atLeastReader struct {
772 R io.Reader
773 N int64
774 }
775
776 func (r *atLeastReader) Read(p []byte) (int, error) {
777 if r.N <= 0 {
778 return 0, io.EOF
779 }
780 n, err := r.R.Read(p)
781 r.N -= int64(n)
782 if r.N > 0 && err == io.EOF {
783 return n, io.ErrUnexpectedEOF
784 }
785 if r.N <= 0 && err == nil {
786 return n, io.EOF
787 }
788 return n, err
789 }
790
791
792
// readFromUntil reads from r into c.rawInput until c.rawInput contains at
// least n bytes or an error occurs.
func (c *Conn) readFromUntil(r io.Reader, n int) error {
	have := c.rawInput.Len()
	if have >= n {
		return nil
	}
	missing := n - have
	// Grow up front by bytes.MinRead extra, so that ReadFrom does not
	// reallocate just to satisfy its own MinRead requirement.
	c.rawInput.Grow(missing + bytes.MinRead)
	_, err := c.rawInput.ReadFrom(&atLeastReader{R: r, N: int64(missing)})
	return err
}
805
806
// sendAlertLocked sends a TLS alert message. The caller must hold c.out's
// lock.
//
// no_renegotiation and close_notify are sent at warning level and every
// other alert at error level. For close_notify only the write error is
// returned. For any other alert, a "local error" net.OpError wrapping it is
// recorded as c.out's sticky error and returned.
func (c *Conn) sendAlertLocked(err alert) error {
	if err == alertNoRenegotiation || err == alertCloseNotify {
		c.tmp[0] = alertLevelWarning
	} else {
		c.tmp[0] = alertLevelError
	}
	c.tmp[1] = byte(err)

	_, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
	if err != alertCloseNotify {
		return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
	}
	// close_notify is not itself an error condition.
	return writeErr
}
824
825
// sendAlert sends a TLS alert message, holding c.out's lock for the
// duration. See sendAlertLocked for the returned error.
func (c *Conn) sendAlert(err alert) error {
	c.out.Lock()
	defer c.out.Unlock()
	return c.sendAlertLocked(err)
}
831
const (
	// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
	// size (MSS). maxPayloadSizeForWrite uses it as the target size of a
	// whole record, header and overheads included, while records are small.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes written to the
	// connection (c.bytesSent) after which maxPayloadSizeForWrite stops
	// limiting application data records and always returns maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the
// next record of type typ.
//
// Application data records start out sized to fit, by estimate, in a single
// TCP segment, and grow linearly with each record. Full-size (maxPlaintext)
// records are used once recordSizeBoostThreshold bytes have been sent, after
// about 1000 records, for non-application-data records, or when
// Config.DynamicRecordSizingDisabled is set. Each call that reaches the
// growth logic increments c.packetsSent.
func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
	if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
		return maxPlaintext
	}

	if c.bytesSent >= recordSizeBoostThreshold {
		return maxPlaintext
	}

	// Subtract the record overheads from the segment estimate.
	payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
	if c.out.cipher != nil {
		switch ciph := c.out.cipher.(type) {
		case cipher.Stream:
			payloadBytes -= c.out.mac.Size()
		case cipher.AEAD:
			payloadBytes -= ciph.Overhead()
		case cbcMode:
			blockSize := ciph.BlockSize()
			// Round down to a whole number of blocks and keep room for at
			// least one padding byte. NOTE(review): the mask assumes
			// blockSize is a power of two.
			payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
			// The MAC sits inside the padded area, so it comes out of the
			// payload directly.
			payloadBytes -= c.out.mac.Size()
		default:
			panic("unknown cipher type")
		}
	}
	if c.vers == VersionTLS13 {
		payloadBytes-- // the encrypted inner content type byte
	}

	// Grow the record size in arithmetic progression, up to the maximum.
	pkt := c.packetsSent
	c.packetsSent++
	if pkt > 1000 {
		return maxPlaintext // also avoids overflow in the multiply below
	}

	n := payloadBytes * int(pkt+1)
	if n > maxPlaintext {
		n = maxPlaintext
	}
	return n
}
908
909 func (c *Conn) write(data []byte) (int, error) {
910 if c.buffering {
911 c.sendBuf = append(c.sendBuf, data...)
912 return len(data), nil
913 }
914
915 n, err := c.conn.Write(data)
916 c.bytesSent += int64(n)
917 return n, err
918 }
919
920 func (c *Conn) flush() (int, error) {
921 if len(c.sendBuf) == 0 {
922 return 0, nil
923 }
924
925 n, err := c.conn.Write(c.sendBuf)
926 c.bytesSent += int64(n)
927 c.sendBuf = nil
928 c.buffering = false
929 return n, err
930 }
931
932
933
// writeRecordLocked writes a TLS record with the given type and payload to
// the connection and updates the record layer state. The caller must hold
// c.out's lock.
//
// data is split into as many records as needed, each no larger than
// maxPayloadSizeForWrite allows. It returns the number of payload bytes
// written. Before TLS 1.3, writing a ChangeCipherSpec record also activates
// the pending write cipher.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		// Reuse c.outBuf for the header; with an empty input the returned
		// tail is the whole header slice.
		_, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen)
		c.outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// No version negotiated yet: label records as TLS 1.0.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 freezes the record layer version at TLS 1.2
			// (RFC 8446, Section 5.1).
			vers = VersionTLS12
		}
		c.outBuf[1] = byte(vers >> 8)
		c.outBuf[2] = byte(vers)
		// Plaintext length for now; encrypt rewrites it to the protected length.
		c.outBuf[3] = byte(m >> 8)
		c.outBuf[4] = byte(m)

		var err error
		c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(c.outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
979
980
981
// writeRecord is writeRecordLocked with c.out's lock held for the duration
// of the call.
func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
	c.out.Lock()
	defer c.out.Unlock()

	return c.writeRecordLocked(typ, data)
}
988
989
990
// readHandshake reads the next handshake message from the record layer,
// decodes it and returns it.
//
// It reads as many records as needed to assemble one complete message in
// c.hand. Oversized messages, unknown message types and malformed messages
// are fatal: an alert is sent and an error is recorded on c.in.
//
// Alerts are sent with sendAlert, which takes c.out's lock. The caller must
// therefore not hold c.out. Callers in this file hold only c.in, through
// Read or Handshake.
func (c *Conn) readHandshake() (interface{}, error) {
	for c.hand.Len() < 4 {
		if err := c.readRecord(); err != nil {
			return nil, err
		}
	}

	data := c.hand.Bytes()
	n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
	if n > maxHandshake {
		// Use sendAlert rather than sendAlertLocked. This path can run
		// post-handshake from Read with only c.in held, and writing the alert
		// without c.out's lock would race with a concurrent Write on
		// c.outBuf and the write sequence number.
		c.sendAlert(alertInternalError)
		return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
	}
	for c.hand.Len() < 4+n {
		if err := c.readRecord(); err != nil {
			return nil, err
		}
	}
	data = c.hand.Next(4 + n)
	var m handshakeMessage
	switch data[0] {
	case typeHelloRequest:
		m = new(helloRequestMsg)
	case typeClientHello:
		m = new(clientHelloMsg)
	case typeServerHello:
		m = new(serverHelloMsg)
	case typeNewSessionTicket:
		if c.vers == VersionTLS13 {
			m = new(newSessionTicketMsgTLS13)
		} else {
			m = new(newSessionTicketMsg)
		}
	case typeCertificate:
		if c.vers == VersionTLS13 {
			m = new(certificateMsgTLS13)
		} else {
			m = new(certificateMsg)
		}
	case typeCertificateRequest:
		if c.vers == VersionTLS13 {
			m = new(certificateRequestMsgTLS13)
		} else {
			m = &certificateRequestMsg{
				hasSignatureAlgorithm: c.vers >= VersionTLS12,
			}
		}
	case typeCertificateStatus:
		m = new(certificateStatusMsg)
	case typeServerKeyExchange:
		m = new(serverKeyExchangeMsg)
	case typeServerHelloDone:
		m = new(serverHelloDoneMsg)
	case typeClientKeyExchange:
		m = new(clientKeyExchangeMsg)
	case typeCertificateVerify:
		m = &certificateVerifyMsg{
			hasSignatureAlgorithm: c.vers >= VersionTLS12,
		}
	case typeNextProtocol:
		m = new(nextProtoMsg)
	case typeFinished:
		m = new(finishedMsg)
	case typeEncryptedExtensions:
		m = new(encryptedExtensionsMsg)
	case typeEndOfEarlyData:
		m = new(endOfEarlyDataMsg)
	case typeKeyUpdate:
		m = new(keyUpdateMsg)
	default:
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	// data aliases c.hand's buffer, which later reads overwrite. Give the
	// message its own copy, since unmarshal may keep slices of it.
	data = append([]byte(nil), data...)

	if !m.unmarshal(data) {
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}
	return m, nil
}
1074
var (
	// errClosed is returned by Write and Close once Close has been called.
	errClosed = errors.New("tls: use of closed connection")
	// errShutdown is returned by Write after a close_notify alert has been
	// sent (see closeNotify).
	errShutdown = errors.New("tls: protocol is shutdown")
)
1079
1080
// Write writes data to the connection.
//
// It runs the handshake first if needed. Write fails with errClosed once
// Close has been called and with errShutdown after a close_notify has been
// sent. Any error from writing records is recorded on c.out and returned by
// all later writes.
func (c *Conn) Write(b []byte) (int, error) {
	// Register this call in activeCall (each active Write adds 2) so that
	// Close can tell a Write is in progress. Refuse to start once Close has
	// set the low bit.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			defer atomic.AddInt32(&c.activeCall, -2)
			break
		}
	}

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	// A previous write failure is sticky.
	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// With SSL 3.0/TLS 1.0 CBC ciphers, the IV of a record is the last
	// ciphertext block of the previous one, which enables a chosen-plaintext
	// attack (BEAST). As a mitigation, the first byte is sent in a record of
	// its own ("1/n-1 record splitting"). That record's MAC is unknown to an
	// attacker, so the IV for the rest of the data is unpredictable.
	var m int
	if len(b) > 1 && c.vers <= VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1136
1137
// handleRenegotiation processes a HelloRequest handshake message received
// after the handshake (TLS 1.2 and earlier only).
//
// Servers refuse it with a no_renegotiation alert. Clients consult
// config.Renegotiation and, if renegotiation is allowed, run a new client
// handshake while holding handshakeMutex. A refusal returns the error from
// sendAlert, which also marks c.out as failed.
func (c *Conn) handleRenegotiation() error {
	if c.vers == VersionTLS13 {
		return errors.New("tls: internal error: unexpected renegotiation")
	}

	msg, err := c.readHandshake()
	if err != nil {
		return err
	}

	helloReq, ok := msg.(*helloRequestMsg)
	if !ok {
		c.sendAlert(alertUnexpectedMessage)
		// helloReq is a typed nil here; only its type is used in the error.
		return unexpectedMessageError(helloReq, msg)
	}

	if !c.isClient {
		return c.sendAlert(alertNoRenegotiation)
	}

	switch c.config.Renegotiation {
	case RenegotiateNever:
		return c.sendAlert(alertNoRenegotiation)
	case RenegotiateOnceAsClient:
		// The initial handshake counts as one, so > 1 means one
		// renegotiation has already happened.
		if c.handshakes > 1 {
			return c.sendAlert(alertNoRenegotiation)
		}
	case RenegotiateFreelyAsClient:
		// Always allowed.
	default:
		c.sendAlert(alertInternalError)
		return errors.New("tls: unknown Renegotiation value")
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Mark the handshake as incomplete while renegotiating.
	atomic.StoreUint32(&c.handshakeStatus, 0)
	if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
		c.handshakes++
	}
	return c.handshakeErr
}
1181
1182
1183
1184 func (c *Conn) handlePostHandshakeMessage() error {
1185 if c.vers != VersionTLS13 {
1186 return c.handleRenegotiation()
1187 }
1188
1189 msg, err := c.readHandshake()
1190 if err != nil {
1191 return err
1192 }
1193
1194 c.retryCount++
1195 if c.retryCount > maxUselessRecords {
1196 c.sendAlert(alertUnexpectedMessage)
1197 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1198 }
1199
1200 switch msg := msg.(type) {
1201 case *newSessionTicketMsgTLS13:
1202 return c.handleNewSessionTicket(msg)
1203 case *keyUpdateMsg:
1204 return c.handleKeyUpdate(msg)
1205 default:
1206 c.sendAlert(alertUnexpectedMessage)
1207 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1208 }
1209 }
1210
// handleKeyUpdate processes a TLS 1.3 KeyUpdate message. It ratchets the
// read traffic secret forward. If the peer requested an update, it also
// sends a KeyUpdate (without requesting one back) and ratchets the write
// traffic secret.
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
	cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
	if cipherSuite == nil {
		return c.in.setErrorLocked(c.sendAlert(alertInternalError))
	}

	newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
	c.in.setTrafficSecret(cipherSuite, newSecret)

	if keyUpdate.updateRequested {
		c.out.Lock()
		defer c.out.Unlock()

		msg := &keyUpdateMsg{}
		_, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
		if err != nil {
			// Record the error on c.out so the next Write reports it; the
			// read side is unaffected.
			c.out.setErrorLocked(err)
			return nil
		}

		// Switch write keys only after our KeyUpdate went out under the
		// old ones.
		newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
		c.out.setTrafficSecret(cipherSuite, newSecret)
	}

	return nil
}
1238
1239
1240
// Read reads application data from the connection.
//
// It runs the handshake first if needed. Read blocks until application data
// is available and returns data from at most one record per call.
// Post-handshake handshake messages (TLS 1.3 session tickets and key updates,
// or renegotiation requests before TLS 1.3) are processed transparently.
// Deadlines are those of the underlying connection; see SetDeadline and
// SetReadDeadline.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Checked after Handshake so that Read(nil) still drives the
		// handshake.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	// c.input is non-empty here, so this Read cannot fail.
	n, _ := c.input.Read(b)

	// If this read drained the current record and the next buffered record
	// is an alert, process it now. A close_notify is then reported together
	// with the final data as (n, io.EOF) instead of on the next call.
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err // io.EOF for close_notify
		}
	}

	return n, nil
}
1283
1284
// Close closes the connection.
//
// If the handshake has completed and no Write is in progress, a close_notify
// alert is sent before the underlying connection is closed. An error closing
// the underlying connection takes precedence over an alert error. Only the
// first call does anything; later calls return errClosed.
func (c *Conn) Close() error {
	// Interlock with Write: set the low bit of activeCall and remember how
	// many Writes were active at that moment.
	var x int32
	for {
		x = atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in flight and holds c.out, so sending close_notify
		// could block behind it. Treat this Close as a request to abort the
		// Write and just close the transport, without an alert.
		return c.conn.Close()
	}

	var alertErr error

	if c.handshakeComplete() {
		alertErr = c.closeNotify()
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1318
// errEarlyCloseWrite is returned by CloseWrite when it is called before the
// handshake has completed.
var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1320
1321
1322
1323
1324 func (c *Conn) CloseWrite() error {
1325 if !c.handshakeComplete() {
1326 return errEarlyCloseWrite
1327 }
1328
1329 return c.closeNotify()
1330 }
1331
1332 func (c *Conn) closeNotify() error {
1333 c.out.Lock()
1334 defer c.out.Unlock()
1335
1336 if !c.closeNotifySent {
1337 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1338 c.closeNotifySent = true
1339 }
1340 return c.closeNotifyErr
1341 }
1342
1343
1344
1345
1346
// Handshake runs the client or server handshake protocol if it has not yet
// been run.
//
// Most uses of this package need not call Handshake explicitly: Read and
// Write call it automatically. A handshake error is stored in c.handshakeErr
// and returned by later calls without retrying.
func (c *Conn) Handshake() error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.handshakeComplete() {
		return nil
	}

	// The handshake reads records, which requires holding c.in.
	c.in.Lock()
	defer c.in.Unlock()

	if c.isClient {
		c.handshakeErr = c.clientHandshake()
	} else {
		c.handshakeErr = c.serverHandshake()
	}
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// Push out anything still buffered in sendBuf (presumably a pending
		// alert). The flush result is ignored in favor of the handshake
		// error.
		c.flush()
	}

	if c.handshakeErr == nil && !c.handshakeComplete() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}

	return c.handshakeErr
}
1380
1381
// ConnectionState returns basic TLS details about the connection.
//
// Before the handshake completes, only HandshakeComplete and ServerName are
// filled in. TLSUnique is set only for full (non-resumed) handshakes before
// TLS 1.3. Keying-material export is disabled whenever the Config allows
// renegotiation.
func (c *Conn) ConnectionState() ConnectionState {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	var state ConnectionState
	state.HandshakeComplete = c.handshakeComplete()
	state.ServerName = c.serverName

	if state.HandshakeComplete {
		state.Version = c.vers
		state.NegotiatedProtocol = c.clientProtocol
		state.DidResume = c.didResume
		state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
		state.CipherSuite = c.cipherSuite
		state.PeerCertificates = c.peerCertificates
		state.VerifiedChains = c.verifiedChains
		state.SignedCertificateTimestamps = c.scts
		state.OCSPResponse = c.ocspResponse
		if !c.didResume && c.vers != VersionTLS13 {
			// tls-unique is the first Finished message of the handshake.
			if c.clientFinishedIsFirst {
				state.TLSUnique = c.clientFinished[:]
			} else {
				state.TLSUnique = c.serverFinished[:]
			}
		}
		if c.config.Renegotiation != RenegotiateNever {
			state.ekm = noExportedKeyingMaterial
		} else {
			state.ekm = c.ekm
		}
	}

	return state
}
1416
1417
1418
// OCSPResponse returns the stapled OCSP response recorded during the
// handshake, if any. It reads the value under handshakeMutex.
func (c *Conn) OCSPResponse() []byte {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	return c.ocspResponse
}
1425
1426
1427
1428
// VerifyHostname checks that the peer certificate chain is valid for
// connecting to host. If so, it returns nil; if not, it returns an error
// describing the problem.
//
// It may only be called on a client connection whose handshake has
// completed and verified the server's certificate chain.
func (c *Conn) VerifyHostname(host string) error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	switch {
	case !c.isClient:
		return errors.New("tls: VerifyHostname called on TLS server connection")
	case !c.handshakeComplete():
		return errors.New("tls: handshake has not yet been performed")
	case len(c.verifiedChains) == 0:
		return errors.New("tls: handshake did not verify certificate chain")
	}
	return c.peerCertificates[0].VerifyHostname(host)
}
1443
// handshakeComplete reports whether the handshake has completed, reading
// c.handshakeStatus atomically so that no lock is needed.
func (c *Conn) handshakeComplete() bool {
	return atomic.LoadUint32(&c.handshakeStatus) == 1
}
1447
View as plain text