Source file src/pkg/database/sql/sql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package sql
17
18 import (
19 "context"
20 "database/sql/driver"
21 "errors"
22 "fmt"
23 "io"
24 "reflect"
25 "runtime"
26 "sort"
27 "strconv"
28 "sync"
29 "sync/atomic"
30 "time"
31 )
32
// Driver registry shared by Register, Drivers and Open.
var (
	// driversMu guards every access to drivers.
	driversMu sync.RWMutex
	// drivers maps a registered name to its driver implementation.
	drivers = map[string]driver.Driver{}
)
37
38
// nowFunc supplies the current time to the pool; it is a variable so it can
// be replaced (presumably by tests — confirm against the test files).
var nowFunc func() time.Time = time.Now
40
41
42
43
44 func Register(name string, driver driver.Driver) {
45 driversMu.Lock()
46 defer driversMu.Unlock()
47 if driver == nil {
48 panic("sql: Register driver is nil")
49 }
50 if _, dup := drivers[name]; dup {
51 panic("sql: Register called twice for driver " + name)
52 }
53 drivers[name] = driver
54 }
55
56 func unregisterAllDrivers() {
57 driversMu.Lock()
58 defer driversMu.Unlock()
59
60 drivers = make(map[string]driver.Driver)
61 }
62
63
64 func Drivers() []string {
65 driversMu.RLock()
66 defer driversMu.RUnlock()
67 var list []string
68 for name := range drivers {
69 list = append(list, name)
70 }
71 sort.Strings(list)
72 return list
73 }
74
75
76
77
78
79
80
// NamedArg pairs a parameter name with the value bound to it.
type NamedArg struct {
	// _Named_Fields_Required is an unexported, zero-size leading field;
	// it presumably exists to force keyed composite literals — confirm.
	_Named_Fields_Required struct{}

	// Name is the parameter name.
	Name string

	// Value is the value associated with Name.
	Value interface{}
}

// Named builds a NamedArg from name and value.
func Named(name string, value interface{}) NamedArg {
	var arg NamedArg
	arg.Name = name
	arg.Value = value
	return arg
}
117
118
// IsolationLevel identifies a transaction isolation level (see TxOptions).
type IsolationLevel int

// The isolation levels known to this package, numbered from zero in
// declaration order.
const (
	LevelDefault IsolationLevel = iota
	LevelReadUncommitted
	LevelReadCommitted
	LevelWriteCommitted
	LevelRepeatableRead
	LevelSnapshot
	LevelSerializable
	LevelLinearizable
)

// String returns the human-readable name of the isolation level, or
// "IsolationLevel(N)" for a value outside the declared constants.
func (i IsolationLevel) String() string {
	names := [...]string{
		LevelDefault:         "Default",
		LevelReadUncommitted: "Read Uncommitted",
		LevelReadCommitted:   "Read Committed",
		LevelWriteCommitted:  "Write Committed",
		LevelRepeatableRead:  "Repeatable Read",
		LevelSnapshot:        "Snapshot",
		LevelSerializable:    "Serializable",
		LevelLinearizable:    "Linearizable",
	}
	if i >= 0 && int(i) < len(names) {
		return names[i]
	}
	return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
}

// Compile-time check that IsolationLevel implements fmt.Stringer.
var _ fmt.Stringer = LevelDefault
161
162
163 type TxOptions struct {
164
165
166 Isolation IsolationLevel
167 ReadOnly bool
168 }
169
170
171
172
// RawBytes is a named byte-slice type.
// NOTE(review): its ownership/lifetime contract is not visible in this
// section — confirm against the Rows.Scan documentation.
type RawBytes []byte
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188 type NullString struct {
189 String string
190 Valid bool
191 }
192
193
194 func (ns *NullString) Scan(value interface{}) error {
195 if value == nil {
196 ns.String, ns.Valid = "", false
197 return nil
198 }
199 ns.Valid = true
200 return convertAssign(&ns.String, value)
201 }
202
203
204 func (ns NullString) Value() (driver.Value, error) {
205 if !ns.Valid {
206 return nil, nil
207 }
208 return ns.String, nil
209 }
210
211
212
213
214 type NullInt64 struct {
215 Int64 int64
216 Valid bool
217 }
218
219
220 func (n *NullInt64) Scan(value interface{}) error {
221 if value == nil {
222 n.Int64, n.Valid = 0, false
223 return nil
224 }
225 n.Valid = true
226 return convertAssign(&n.Int64, value)
227 }
228
229
230 func (n NullInt64) Value() (driver.Value, error) {
231 if !n.Valid {
232 return nil, nil
233 }
234 return n.Int64, nil
235 }
236
237
238
239
240 type NullInt32 struct {
241 Int32 int32
242 Valid bool
243 }
244
245
246 func (n *NullInt32) Scan(value interface{}) error {
247 if value == nil {
248 n.Int32, n.Valid = 0, false
249 return nil
250 }
251 n.Valid = true
252 return convertAssign(&n.Int32, value)
253 }
254
255
256 func (n NullInt32) Value() (driver.Value, error) {
257 if !n.Valid {
258 return nil, nil
259 }
260 return int64(n.Int32), nil
261 }
262
263
264
265
266 type NullFloat64 struct {
267 Float64 float64
268 Valid bool
269 }
270
271
272 func (n *NullFloat64) Scan(value interface{}) error {
273 if value == nil {
274 n.Float64, n.Valid = 0, false
275 return nil
276 }
277 n.Valid = true
278 return convertAssign(&n.Float64, value)
279 }
280
281
282 func (n NullFloat64) Value() (driver.Value, error) {
283 if !n.Valid {
284 return nil, nil
285 }
286 return n.Float64, nil
287 }
288
289
290
291
292 type NullBool struct {
293 Bool bool
294 Valid bool
295 }
296
297
298 func (n *NullBool) Scan(value interface{}) error {
299 if value == nil {
300 n.Bool, n.Valid = false, false
301 return nil
302 }
303 n.Valid = true
304 return convertAssign(&n.Bool, value)
305 }
306
307
308 func (n NullBool) Value() (driver.Value, error) {
309 if !n.Valid {
310 return nil, nil
311 }
312 return n.Bool, nil
313 }
314
315
316
317
318 type NullTime struct {
319 Time time.Time
320 Valid bool
321 }
322
323
324 func (n *NullTime) Scan(value interface{}) error {
325 if value == nil {
326 n.Time, n.Valid = time.Time{}, false
327 return nil
328 }
329 n.Valid = true
330 return convertAssign(&n.Time, value)
331 }
332
333
334 func (n NullTime) Value() (driver.Value, error) {
335 if !n.Valid {
336 return nil, nil
337 }
338 return n.Time, nil
339 }
340
341
// Scanner is implemented by types that can load themselves from a value.
// The Null* types in this file implement it with pointer receivers.
type Scanner interface {
	// Scan assigns src to the receiver and returns an error if the value
	// cannot be stored. The Null* implementations in this file treat a nil
	// src as NULL.
	Scan(src interface{}) error
}
363
364
365
366
367
368
369
370
371
// Out describes an output parameter.
// NOTE(review): how drivers consume Out is not visible in this section.
type Out struct {
	// _Named_Fields_Required is an unexported, zero-size leading field;
	// it presumably exists to force keyed composite literals — confirm.
	_Named_Fields_Required struct{}

	// Dest is the destination associated with the parameter
	// (presumably a pointer that receives the output — confirm).
	Dest interface{}

	// In marks the parameter as also carrying an input value
	// (presumably an INOUT parameter — confirm).
	In bool
}
384
385
386
387
// ErrNoRows is the sentinel error for an empty result set.
var ErrNoRows error = errors.New("sql: no rows in result set")
389
390
391
392
393
394
395
396
397
398
399
400
401
402 type DB struct {
403
404
405 waitDuration int64
406
407 connector driver.Connector
408
409
410
411 numClosed uint64
412
413 mu sync.Mutex
414 freeConn []*driverConn
415 connRequests map[uint64]chan connRequest
416 nextRequest uint64
417 numOpen int
418
419
420
421
422
423 openerCh chan struct{}
424 resetterCh chan *driverConn
425 closed bool
426 dep map[finalCloser]depSet
427 lastPut map[*driverConn]string
428 maxIdle int
429 maxOpen int
430 maxLifetime time.Duration
431 cleanerCh chan struct{}
432 waitCount int64
433 maxIdleClosed int64
434 maxLifetimeClosed int64
435
436 stop func()
437 }
438
439
// connReuseStrategy tells DB.conn whether a pooled connection may be reused.
type connReuseStrategy uint8

const (
	// alwaysNewConn skips the free list.
	alwaysNewConn connReuseStrategy = iota

	// cachedOrNewConn takes a connection from the free list when one is
	// available.
	cachedOrNewConn
)
450
451
452
453
454
455 type driverConn struct {
456 db *DB
457 createdAt time.Time
458
459 sync.Mutex
460 ci driver.Conn
461 closed bool
462 finalClosed bool
463 openStmt map[*driverStmt]bool
464 lastErr error
465
466
467 inUse bool
468 onPut []func()
469 dbmuClosed bool
470 }
471
472 func (dc *driverConn) releaseConn(err error) {
473 dc.db.putConn(dc, err, true)
474 }
475
476 func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
477 dc.Lock()
478 defer dc.Unlock()
479 delete(dc.openStmt, ds)
480 }
481
482 func (dc *driverConn) expired(timeout time.Duration) bool {
483 if timeout <= 0 {
484 return false
485 }
486 return dc.createdAt.Add(timeout).Before(nowFunc())
487 }
488
489
490
491 func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
492 si, err := ctxDriverPrepare(ctx, dc.ci, query)
493 if err != nil {
494 return nil, err
495 }
496 ds := &driverStmt{Locker: dc, si: si}
497
498
499 if cg != nil {
500 return ds, nil
501 }
502
503
504
505
506
507 if dc.openStmt == nil {
508 dc.openStmt = make(map[*driverStmt]bool)
509 }
510 dc.openStmt[ds] = true
511 return ds, nil
512 }
513
514
515
516
517
518
519 func (dc *driverConn) resetSession(ctx context.Context) {
520 defer dc.Unlock()
521 if dc.closed {
522 return
523 }
524 dc.lastErr = dc.ci.(driver.SessionResetter).ResetSession(ctx)
525 }
526
527
528 func (dc *driverConn) closeDBLocked() func() error {
529 dc.Lock()
530 defer dc.Unlock()
531 if dc.closed {
532 return func() error { return errors.New("sql: duplicate driverConn close") }
533 }
534 dc.closed = true
535 return dc.db.removeDepLocked(dc, dc)
536 }
537
538 func (dc *driverConn) Close() error {
539 dc.Lock()
540 if dc.closed {
541 dc.Unlock()
542 return errors.New("sql: duplicate driverConn close")
543 }
544 dc.closed = true
545 dc.Unlock()
546
547
548 dc.db.mu.Lock()
549 dc.dbmuClosed = true
550 fn := dc.db.removeDepLocked(dc, dc)
551 dc.db.mu.Unlock()
552 return fn()
553 }
554
555 func (dc *driverConn) finalClose() error {
556 var err error
557
558
559
560 var openStmt []*driverStmt
561 withLock(dc, func() {
562 openStmt = make([]*driverStmt, 0, len(dc.openStmt))
563 for ds := range dc.openStmt {
564 openStmt = append(openStmt, ds)
565 }
566 dc.openStmt = nil
567 })
568 for _, ds := range openStmt {
569 ds.Close()
570 }
571 withLock(dc, func() {
572 dc.finalClosed = true
573 err = dc.ci.Close()
574 dc.ci = nil
575 })
576
577 dc.db.mu.Lock()
578 dc.db.numOpen--
579 dc.db.maybeOpenNewConnections()
580 dc.db.mu.Unlock()
581
582 atomic.AddUint64(&dc.db.numClosed, 1)
583 return err
584 }
585
586
587
588
// driverStmt couples a driver.Stmt with the lock that must be held while
// using it (in this file the lock is the owning driverConn).
type driverStmt struct {
	sync.Locker // lock taken around use of si
	si          driver.Stmt
	closed      bool  // set once Close has run
	closeErr    error // result of the first Close
}

// Close closes the underlying statement once; later calls return the error
// recorded by the first call.
func (ds *driverStmt) Close() error {
	ds.Lock()
	defer ds.Unlock()
	if !ds.closed {
		ds.closed = true
		ds.closeErr = ds.si.Close()
	}
	return ds.closeErr
}
608
609
// depSet is the set of dependents recorded for one finalCloser.
type depSet map[interface{}]bool
611
612
613
// finalCloser is implemented by values tracked in DB.dep.
type finalCloser interface {
	// finalClose is the function removeDepLocked returns once the last
	// dependency on the value has been removed.
	finalClose() error
}
619
620
621
622 func (db *DB) addDep(x finalCloser, dep interface{}) {
623 db.mu.Lock()
624 defer db.mu.Unlock()
625 db.addDepLocked(x, dep)
626 }
627
628 func (db *DB) addDepLocked(x finalCloser, dep interface{}) {
629 if db.dep == nil {
630 db.dep = make(map[finalCloser]depSet)
631 }
632 xdep := db.dep[x]
633 if xdep == nil {
634 xdep = make(depSet)
635 db.dep[x] = xdep
636 }
637 xdep[dep] = true
638 }
639
640
641
642
643
644 func (db *DB) removeDep(x finalCloser, dep interface{}) error {
645 db.mu.Lock()
646 fn := db.removeDepLocked(x, dep)
647 db.mu.Unlock()
648 return fn()
649 }
650
651 func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error {
652
653 xdep, ok := db.dep[x]
654 if !ok {
655 panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
656 }
657
658 l0 := len(xdep)
659 delete(xdep, dep)
660
661 switch len(xdep) {
662 case l0:
663
664 panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
665 case 0:
666
667 delete(db.dep, x)
668 return x.finalClose
669 default:
670
671 return func() error { return nil }
672 }
673 }
674
675
676
677
678
679
// connectionRequestQueueSize is the buffer size OpenDB gives DB.openerCh.
var connectionRequestQueueSize int = 1000000
681
// dsnConnector is the driver.Connector that Open uses for drivers without
// DriverContext support: it opens connections from a fixed DSN.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

// Connect opens a connection through the wrapped driver's Open method; the
// context argument is ignored.
func (c dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
	conn, err := c.driver.Open(c.dsn)
	return conn, err
}

// Driver returns the wrapped driver.
func (c dsnConnector) Driver() driver.Driver {
	return c.driver
}
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711 func OpenDB(c driver.Connector) *DB {
712 ctx, cancel := context.WithCancel(context.Background())
713 db := &DB{
714 connector: c,
715 openerCh: make(chan struct{}, connectionRequestQueueSize),
716 resetterCh: make(chan *driverConn, 50),
717 lastPut: make(map[*driverConn]string),
718 connRequests: make(map[uint64]chan connRequest),
719 stop: cancel,
720 }
721
722 go db.connectionOpener(ctx)
723 go db.connectionResetter(ctx)
724
725 return db
726 }
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745 func Open(driverName, dataSourceName string) (*DB, error) {
746 driversMu.RLock()
747 driveri, ok := drivers[driverName]
748 driversMu.RUnlock()
749 if !ok {
750 return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
751 }
752
753 if driverCtx, ok := driveri.(driver.DriverContext); ok {
754 connector, err := driverCtx.OpenConnector(dataSourceName)
755 if err != nil {
756 return nil, err
757 }
758 return OpenDB(connector), nil
759 }
760
761 return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
762 }
763
764 func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
765 var err error
766 if pinger, ok := dc.ci.(driver.Pinger); ok {
767 withLock(dc, func() {
768 err = pinger.Ping(ctx)
769 })
770 }
771 release(err)
772 return err
773 }
774
775
776
777 func (db *DB) PingContext(ctx context.Context) error {
778 var dc *driverConn
779 var err error
780
781 for i := 0; i < maxBadConnRetries; i++ {
782 dc, err = db.conn(ctx, cachedOrNewConn)
783 if err != driver.ErrBadConn {
784 break
785 }
786 }
787 if err == driver.ErrBadConn {
788 dc, err = db.conn(ctx, alwaysNewConn)
789 }
790 if err != nil {
791 return err
792 }
793
794 return db.pingDC(ctx, dc, dc.releaseConn)
795 }
796
797
798
799 func (db *DB) Ping() error {
800 return db.PingContext(context.Background())
801 }
802
803
804
805
806
807
808
809 func (db *DB) Close() error {
810 db.mu.Lock()
811 if db.closed {
812 db.mu.Unlock()
813 return nil
814 }
815 if db.cleanerCh != nil {
816 close(db.cleanerCh)
817 }
818 var err error
819 fns := make([]func() error, 0, len(db.freeConn))
820 for _, dc := range db.freeConn {
821 fns = append(fns, dc.closeDBLocked())
822 }
823 db.freeConn = nil
824 db.closed = true
825 for _, req := range db.connRequests {
826 close(req)
827 }
828 db.mu.Unlock()
829 for _, fn := range fns {
830 err1 := fn()
831 if err1 != nil {
832 err = err1
833 }
834 }
835 db.stop()
836 return err
837 }
838
// defaultMaxIdleConns is the idle-connection limit used when DB.maxIdle is
// zero.
const defaultMaxIdleConns int = 2
840
841 func (db *DB) maxIdleConnsLocked() int {
842 n := db.maxIdle
843 switch {
844 case n == 0:
845
846 return defaultMaxIdleConns
847 case n < 0:
848 return 0
849 default:
850 return n
851 }
852 }
853
854
855
856
857
858
859
860
861
862
863
864 func (db *DB) SetMaxIdleConns(n int) {
865 db.mu.Lock()
866 if n > 0 {
867 db.maxIdle = n
868 } else {
869
870 db.maxIdle = -1
871 }
872
873 if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
874 db.maxIdle = db.maxOpen
875 }
876 var closing []*driverConn
877 idleCount := len(db.freeConn)
878 maxIdle := db.maxIdleConnsLocked()
879 if idleCount > maxIdle {
880 closing = db.freeConn[maxIdle:]
881 db.freeConn = db.freeConn[:maxIdle]
882 }
883 db.maxIdleClosed += int64(len(closing))
884 db.mu.Unlock()
885 for _, c := range closing {
886 c.Close()
887 }
888 }
889
890
891
892
893
894
895
896
897
898 func (db *DB) SetMaxOpenConns(n int) {
899 db.mu.Lock()
900 db.maxOpen = n
901 if n < 0 {
902 db.maxOpen = 0
903 }
904 syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
905 db.mu.Unlock()
906 if syncMaxIdle {
907 db.SetMaxIdleConns(n)
908 }
909 }
910
911
912
913
914
915
916 func (db *DB) SetConnMaxLifetime(d time.Duration) {
917 if d < 0 {
918 d = 0
919 }
920 db.mu.Lock()
921
922 if d > 0 && d < db.maxLifetime && db.cleanerCh != nil {
923 select {
924 case db.cleanerCh <- struct{}{}:
925 default:
926 }
927 }
928 db.maxLifetime = d
929 db.startCleanerLocked()
930 db.mu.Unlock()
931 }
932
933
934 func (db *DB) startCleanerLocked() {
935 if db.maxLifetime > 0 && db.numOpen > 0 && db.cleanerCh == nil {
936 db.cleanerCh = make(chan struct{}, 1)
937 go db.connectionCleaner(db.maxLifetime)
938 }
939 }
940
941 func (db *DB) connectionCleaner(d time.Duration) {
942 const minInterval = time.Second
943
944 if d < minInterval {
945 d = minInterval
946 }
947 t := time.NewTimer(d)
948
949 for {
950 select {
951 case <-t.C:
952 case <-db.cleanerCh:
953 }
954
955 db.mu.Lock()
956 d = db.maxLifetime
957 if db.closed || db.numOpen == 0 || d <= 0 {
958 db.cleanerCh = nil
959 db.mu.Unlock()
960 return
961 }
962
963 expiredSince := nowFunc().Add(-d)
964 var closing []*driverConn
965 for i := 0; i < len(db.freeConn); i++ {
966 c := db.freeConn[i]
967 if c.createdAt.Before(expiredSince) {
968 closing = append(closing, c)
969 last := len(db.freeConn) - 1
970 db.freeConn[i] = db.freeConn[last]
971 db.freeConn[last] = nil
972 db.freeConn = db.freeConn[:last]
973 i--
974 }
975 }
976 db.maxLifetimeClosed += int64(len(closing))
977 db.mu.Unlock()
978
979 for _, c := range closing {
980 c.Close()
981 }
982
983 if d < minInterval {
984 d = minInterval
985 }
986 t.Reset(d)
987 }
988 }
989
990
// DBStats is the snapshot of pool statistics returned by DB.Stats.
type DBStats struct {
	MaxOpenConnections int // configured maximum; 0 means unlimited

	// Pool status.
	OpenConnections int // DB.numOpen at the time of the snapshot
	InUse           int // OpenConnections minus Idle
	Idle            int // connections on the free list

	// Counters.
	WaitCount         int64         // number of connection waits
	WaitDuration      time.Duration // total time spent waiting for a connection
	MaxIdleClosed     int64         // connections closed due to the idle limit
	MaxLifetimeClosed int64         // connections closed due to the lifetime limit
}
1005
1006
1007 func (db *DB) Stats() DBStats {
1008 wait := atomic.LoadInt64(&db.waitDuration)
1009
1010 db.mu.Lock()
1011 defer db.mu.Unlock()
1012
1013 stats := DBStats{
1014 MaxOpenConnections: db.maxOpen,
1015
1016 Idle: len(db.freeConn),
1017 OpenConnections: db.numOpen,
1018 InUse: db.numOpen - len(db.freeConn),
1019
1020 WaitCount: db.waitCount,
1021 WaitDuration: time.Duration(wait),
1022 MaxIdleClosed: db.maxIdleClosed,
1023 MaxLifetimeClosed: db.maxLifetimeClosed,
1024 }
1025 return stats
1026 }
1027
1028
1029
1030
1031 func (db *DB) maybeOpenNewConnections() {
1032 numRequests := len(db.connRequests)
1033 if db.maxOpen > 0 {
1034 numCanOpen := db.maxOpen - db.numOpen
1035 if numRequests > numCanOpen {
1036 numRequests = numCanOpen
1037 }
1038 }
1039 for numRequests > 0 {
1040 db.numOpen++
1041 numRequests--
1042 if db.closed {
1043 return
1044 }
1045 db.openerCh <- struct{}{}
1046 }
1047 }
1048
1049
1050 func (db *DB) connectionOpener(ctx context.Context) {
1051 for {
1052 select {
1053 case <-ctx.Done():
1054 return
1055 case <-db.openerCh:
1056 db.openNewConnection(ctx)
1057 }
1058 }
1059 }
1060
1061
1062
1063 func (db *DB) connectionResetter(ctx context.Context) {
1064 for {
1065 select {
1066 case <-ctx.Done():
1067 close(db.resetterCh)
1068 for dc := range db.resetterCh {
1069 dc.Unlock()
1070 }
1071 return
1072 case dc := <-db.resetterCh:
1073 dc.resetSession(ctx)
1074 }
1075 }
1076 }
1077
1078
1079 func (db *DB) openNewConnection(ctx context.Context) {
1080
1081
1082
1083 ci, err := db.connector.Connect(ctx)
1084 db.mu.Lock()
1085 defer db.mu.Unlock()
1086 if db.closed {
1087 if err == nil {
1088 ci.Close()
1089 }
1090 db.numOpen--
1091 return
1092 }
1093 if err != nil {
1094 db.numOpen--
1095 db.putConnDBLocked(nil, err)
1096 db.maybeOpenNewConnections()
1097 return
1098 }
1099 dc := &driverConn{
1100 db: db,
1101 createdAt: nowFunc(),
1102 ci: ci,
1103 }
1104 if db.putConnDBLocked(dc, err) {
1105 db.addDepLocked(dc, dc)
1106 } else {
1107 db.numOpen--
1108 ci.Close()
1109 }
1110 }
1111
1112
1113
1114
1115 type connRequest struct {
1116 conn *driverConn
1117 err error
1118 }
1119
// errDBClosed is returned by DB.conn when the DB has been closed.
var errDBClosed error = errors.New("sql: database is closed")
1121
1122
1123
1124 func (db *DB) nextRequestKeyLocked() uint64 {
1125 next := db.nextRequest
1126 db.nextRequest++
1127 return next
1128 }
1129 // conn returns a pooled or newly opened *driverConn according to strategy.
1130 // It returns errDBClosed when db is closed and ctx.Err() when ctx is done.
1131 func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
1132 db.mu.Lock()
1133 if db.closed {
1134 db.mu.Unlock()
1135 return nil, errDBClosed
1136 }
1137 // Non-blocking check: give up early if ctx is already done.
1138 select {
1139 default:
1140 case <-ctx.Done():
1141 db.mu.Unlock()
1142 return nil, ctx.Err()
1143 }
1144 lifetime := db.maxLifetime
1145
1146 // Take the first free connection when the strategy allows reuse.
1147 numFree := len(db.freeConn)
1148 if strategy == cachedOrNewConn && numFree > 0 {
1149 conn := db.freeConn[0]
1150 copy(db.freeConn, db.freeConn[1:])
1151 db.freeConn = db.freeConn[:numFree-1]
1152 conn.inUse = true
1153 db.mu.Unlock()
1154 if conn.expired(lifetime) {
1155 conn.Close()
1156 return nil, driver.ErrBadConn
1157 }
1158 // Taking the conn lock waits for any session reset still holding it.
1159 conn.Lock()
1160 err := conn.lastErr
1161 conn.Unlock()
1162 if err == driver.ErrBadConn {
1163 conn.Close()
1164 return nil, driver.ErrBadConn
1165 }
1166 return conn, nil
1167 }
1168
1169 // At the open-connection limit: register a request and wait for a
1170 // connection to be handed over by putConnDBLocked.
1171 if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
1172 // The channel is buffered so the sender never blocks, even if this
1173 // caller has already given up.
1174 req := make(chan connRequest, 1)
1175 reqKey := db.nextRequestKeyLocked()
1176 db.connRequests[reqKey] = req
1177 db.waitCount++
1178 db.mu.Unlock()
1179
1180 waitStart := time.Now()
1181
1182 // Wait for either cancellation or a response on req.
1183 select {
1184 case <-ctx.Done():
1185 // Withdraw the request so no connection is sent to it later.
1186
1187 db.mu.Lock()
1188 delete(db.connRequests, reqKey)
1189 db.mu.Unlock()
1190
1191 atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))
1192 // A connection may have been delivered before the request was removed; return it to the pool.
1193 select {
1194 default:
1195 case ret, ok := <-req:
1196 if ok && ret.conn != nil {
1197 db.putConn(ret.conn, ret.err, false)
1198 }
1199 }
1200 return nil, ctx.Err()
1201 case ret, ok := <-req:
1202 atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))
1203 // A closed channel means DB.Close ran while this caller was waiting.
1204 if !ok {
1205 return nil, errDBClosed
1206 }
1207 if ret.err == nil && ret.conn.expired(lifetime) {
1208 ret.conn.Close()
1209 return nil, driver.ErrBadConn
1210 }
1211 if ret.conn == nil {
1212 return nil, ret.err
1213 }
1214 // Taking the conn lock waits for any session reset still holding it.
1215 ret.conn.Lock()
1216 err := ret.conn.lastErr
1217 ret.conn.Unlock()
1218 if err == driver.ErrBadConn {
1219 ret.conn.Close()
1220 return nil, driver.ErrBadConn
1221 }
1222 return ret.conn, ret.err
1223 }
1224 }
1225 // Below the limit: count the connection first, then open it without holding db.mu.
1226 db.numOpen++
1227 db.mu.Unlock()
1228 ci, err := db.connector.Connect(ctx)
1229 if err != nil {
1230 db.mu.Lock()
1231 db.numOpen--
1232 db.maybeOpenNewConnections()
1233 db.mu.Unlock()
1234 return nil, err
1235 }
1236 db.mu.Lock()
1237 dc := &driverConn{
1238 db: db,
1239 createdAt: nowFunc(),
1240 ci: ci,
1241 inUse: true,
1242 }
1243 db.addDepLocked(dc, dc)
1244 db.mu.Unlock()
1245 return dc, nil
1246 }
1247
1248
1249 var putConnHook func(*DB, *driverConn)
1250
1251
1252
1253
1254 func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
1255 db.mu.Lock()
1256 defer db.mu.Unlock()
1257 if c.inUse {
1258 c.onPut = append(c.onPut, func() {
1259 ds.Close()
1260 })
1261 } else {
1262 c.Lock()
1263 fc := c.finalClosed
1264 c.Unlock()
1265 if !fc {
1266 ds.Close()
1267 }
1268 }
1269 }
1270
1271
1272
// debugGetPut enables stack recording and diagnostics in putConn.
const debugGetPut bool = false
1274
1275
1276
1277 func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
1278 db.mu.Lock()
1279 if !dc.inUse {
1280 if debugGetPut {
1281 fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
1282 }
1283 panic("sql: connection returned that was never out")
1284 }
1285 if debugGetPut {
1286 db.lastPut[dc] = stack()
1287 }
1288 dc.inUse = false
1289
1290 for _, fn := range dc.onPut {
1291 fn()
1292 }
1293 dc.onPut = nil
1294
1295 if err == driver.ErrBadConn {
1296
1297
1298
1299
1300 db.maybeOpenNewConnections()
1301 db.mu.Unlock()
1302 dc.Close()
1303 return
1304 }
1305 if putConnHook != nil {
1306 putConnHook(db, dc)
1307 }
1308 if db.closed {
1309
1310
1311 resetSession = false
1312 }
1313 if resetSession {
1314 if _, resetSession = dc.ci.(driver.SessionResetter); resetSession {
1315
1316
1317
1318
1319 dc.Lock()
1320 }
1321 }
1322 added := db.putConnDBLocked(dc, nil)
1323 db.mu.Unlock()
1324
1325 if !added {
1326 if resetSession {
1327 dc.Unlock()
1328 }
1329 dc.Close()
1330 return
1331 }
1332 if !resetSession {
1333 return
1334 }
1335 select {
1336 default:
1337
1338
1339 dc.lastErr = driver.ErrBadConn
1340 dc.Unlock()
1341 case db.resetterCh <- dc:
1342 }
1343 }
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354 func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
1355 if db.closed {
1356 return false
1357 }
1358 if db.maxOpen > 0 && db.numOpen > db.maxOpen {
1359 return false
1360 }
1361 if c := len(db.connRequests); c > 0 {
1362 var req chan connRequest
1363 var reqKey uint64
1364 for reqKey, req = range db.connRequests {
1365 break
1366 }
1367 delete(db.connRequests, reqKey)
1368 if err == nil {
1369 dc.inUse = true
1370 }
1371 req <- connRequest{
1372 conn: dc,
1373 err: err,
1374 }
1375 return true
1376 } else if err == nil && !db.closed {
1377 if db.maxIdleConnsLocked() > len(db.freeConn) {
1378 db.freeConn = append(db.freeConn, dc)
1379 db.startCleanerLocked()
1380 return true
1381 }
1382 db.maxIdleClosed++
1383 }
1384 return false
1385 }
1386
1387
1388
1389
// maxBadConnRetries is the number of attempts made with a possibly cached
// connection before an operation falls back to forcing a new connection.
const maxBadConnRetries int = 2
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
1401 var stmt *Stmt
1402 var err error
1403 for i := 0; i < maxBadConnRetries; i++ {
1404 stmt, err = db.prepare(ctx, query, cachedOrNewConn)
1405 if err != driver.ErrBadConn {
1406 break
1407 }
1408 }
1409 if err == driver.ErrBadConn {
1410 return db.prepare(ctx, query, alwaysNewConn)
1411 }
1412 return stmt, err
1413 }
1414
1415
1416
1417
1418
1419
1420 func (db *DB) Prepare(query string) (*Stmt, error) {
1421 return db.PrepareContext(context.Background(), query)
1422 }
1423
1424 func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
1425
1426
1427
1428
1429
1430
1431 dc, err := db.conn(ctx, strategy)
1432 if err != nil {
1433 return nil, err
1434 }
1435 return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
1436 }
1437
1438
1439
1440
1441 func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
1442 var ds *driverStmt
1443 var err error
1444 defer func() {
1445 release(err)
1446 }()
1447 withLock(dc, func() {
1448 ds, err = dc.prepareLocked(ctx, cg, query)
1449 })
1450 if err != nil {
1451 return nil, err
1452 }
1453 stmt := &Stmt{
1454 db: db,
1455 query: query,
1456 cg: cg,
1457 cgds: ds,
1458 }
1459
1460
1461
1462
1463 if cg == nil {
1464 stmt.css = []connStmt{{dc, ds}}
1465 stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed)
1466 db.addDep(stmt, stmt)
1467 }
1468 return stmt, nil
1469 }
1470
1471
1472
1473 func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
1474 var res Result
1475 var err error
1476 for i := 0; i < maxBadConnRetries; i++ {
1477 res, err = db.exec(ctx, query, args, cachedOrNewConn)
1478 if err != driver.ErrBadConn {
1479 break
1480 }
1481 }
1482 if err == driver.ErrBadConn {
1483 return db.exec(ctx, query, args, alwaysNewConn)
1484 }
1485 return res, err
1486 }
1487
1488
1489
1490 func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
1491 return db.ExecContext(context.Background(), query, args...)
1492 }
1493
1494 func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (Result, error) {
1495 dc, err := db.conn(ctx, strategy)
1496 if err != nil {
1497 return nil, err
1498 }
1499 return db.execDC(ctx, dc, dc.releaseConn, query, args)
1500 }
1501 // execDC executes query on dc; release is always called on return with the final value of err.
1502 func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []interface{}) (res Result, err error) {
// The deferred closure reads the named result err when the function returns.
1503 defer func() {
1504 release(err)
1505 }()
// Fast path: use the connection's ExecerContext, or else its Execer, if implemented.
1506 execerCtx, ok := dc.ci.(driver.ExecerContext)
1507 var execer driver.Execer
1508 if !ok {
1509 execer, ok = dc.ci.(driver.Execer)
1510 }
1511 if ok {
1512 var nvdargs []driver.NamedValue
1513 var resi driver.Result
1514 withLock(dc, func() {
1515 nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
1516 if err != nil {
1517 return
1518 }
1519 resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
1520 })
// driver.ErrSkip falls through to the prepare-and-execute path below.
1521 if err != driver.ErrSkip {
1522 if err != nil {
1523 return nil, err
1524 }
1525 return driverResult{dc, resi}, nil
1526 }
1527 }
1528 // Slow path: prepare a statement, execute it, and close it on return.
1529 var si driver.Stmt
1530 withLock(dc, func() {
1531 si, err = ctxDriverPrepare(ctx, dc.ci, query)
1532 })
1533 if err != nil {
1534 return nil, err
1535 }
1536 ds := &driverStmt{Locker: dc, si: si}
1537 defer ds.Close()
1538 return resultFromStatement(ctx, dc.ci, ds, args...)
1539 }
1540
1541
1542
1543 func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
1544 var rows *Rows
1545 var err error
1546 for i := 0; i < maxBadConnRetries; i++ {
1547 rows, err = db.query(ctx, query, args, cachedOrNewConn)
1548 if err != driver.ErrBadConn {
1549 break
1550 }
1551 }
1552 if err == driver.ErrBadConn {
1553 return db.query(ctx, query, args, alwaysNewConn)
1554 }
1555 return rows, err
1556 }
1557
1558
1559
1560 func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
1561 return db.QueryContext(context.Background(), query, args...)
1562 }
1563
1564 func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
1565 dc, err := db.conn(ctx, strategy)
1566 if err != nil {
1567 return nil, err
1568 }
1569
1570 return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
1571 }
1572
1573
1574
1575 // queryDC runs query on dc. On error it calls releaseConn itself; on success
1576 // the returned Rows is given releaseConn. txctx is passed on to initContextClose.
1577 func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
// Fast path: use the connection's QueryerContext, or else its Queryer, if implemented.
1578 queryerCtx, ok := dc.ci.(driver.QueryerContext)
1579 var queryer driver.Queryer
1580 if !ok {
1581 queryer, ok = dc.ci.(driver.Queryer)
1582 }
1583 if ok {
1584 var nvdargs []driver.NamedValue
1585 var rowsi driver.Rows
1586 var err error
1587 withLock(dc, func() {
1588 nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
1589 if err != nil {
1590 return
1591 }
1592 rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
1593 })
// driver.ErrSkip falls through to the prepare-and-query path below.
1594 if err != driver.ErrSkip {
1595 if err != nil {
1596 releaseConn(err)
1597 return nil, err
1598 }
1599 // The connection is not released here; the Rows value carries
1600 // releaseConn from now on.
1601 rows := &Rows{
1602 dc: dc,
1603 releaseConn: releaseConn,
1604 rowsi: rowsi,
1605 }
1606 rows.initContextClose(ctx, txctx)
1607 return rows, nil
1608 }
1609 }
1610 // Slow path: prepare a statement and query through it.
1611 var si driver.Stmt
1612 var err error
1613 withLock(dc, func() {
1614 si, err = ctxDriverPrepare(ctx, dc.ci, query)
1615 })
1616 if err != nil {
1617 releaseConn(err)
1618 return nil, err
1619 }
1620
1621 ds := &driverStmt{Locker: dc, si: si}
1622 rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
1623 if err != nil {
1624 ds.Close()
1625 releaseConn(err)
1626 return nil, err
1627 }
1628
1629 // The Rows value carries releaseConn and, via closeStmt, the prepared
1630 // statement (presumably closed when the Rows is closed — confirm in Rows).
1631 rows := &Rows{
1632 dc: dc,
1633 releaseConn: releaseConn,
1634 rowsi: rowsi,
1635 closeStmt: ds,
1636 }
1637 rows.initContextClose(ctx, txctx)
1638 return rows, nil
1639 }
1640
1641
1642
1643
1644
1645
1646
1647 func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
1648 rows, err := db.QueryContext(ctx, query, args...)
1649 return &Row{rows: rows, err: err}
1650 }
1651
1652
1653
1654
1655
1656
1657
1658 func (db *DB) QueryRow(query string, args ...interface{}) *Row {
1659 return db.QueryRowContext(context.Background(), query, args...)
1660 }
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672 func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
1673 var tx *Tx
1674 var err error
1675 for i := 0; i < maxBadConnRetries; i++ {
1676 tx, err = db.begin(ctx, opts, cachedOrNewConn)
1677 if err != driver.ErrBadConn {
1678 break
1679 }
1680 }
1681 if err == driver.ErrBadConn {
1682 return db.begin(ctx, opts, alwaysNewConn)
1683 }
1684 return tx, err
1685 }
1686
1687
1688
1689 func (db *DB) Begin() (*Tx, error) {
1690 return db.BeginTx(context.Background(), nil)
1691 }
1692
1693 func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
1694 dc, err := db.conn(ctx, strategy)
1695 if err != nil {
1696 return nil, err
1697 }
1698 return db.beginDC(ctx, dc, dc.releaseConn, opts)
1699 }
1700
1701
1702 func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
1703 var txi driver.Tx
1704 withLock(dc, func() {
1705 txi, err = ctxDriverBegin(ctx, opts, dc.ci)
1706 })
1707 if err != nil {
1708 release(err)
1709 return nil, err
1710 }
1711
1712
1713
1714 ctx, cancel := context.WithCancel(ctx)
1715 tx = &Tx{
1716 db: db,
1717 dc: dc,
1718 releaseConn: release,
1719 txi: txi,
1720 cancel: cancel,
1721 ctx: ctx,
1722 }
1723 go tx.awaitDone()
1724 return tx, nil
1725 }
1726
1727
1728 func (db *DB) Driver() driver.Driver {
1729 return db.connector.Driver()
1730 }
1731
1732
1733
// ErrConnDone is returned by Conn.grabConn, and so by the Conn methods
// built on it, once the Conn has been marked done.
var ErrConnDone error = errors.New("sql: connection is already closed")
1735
1736
1737
1738
1739
1740
1741
1742
1743 func (db *DB) Conn(ctx context.Context) (*Conn, error) {
1744 var dc *driverConn
1745 var err error
1746 for i := 0; i < maxBadConnRetries; i++ {
1747 dc, err = db.conn(ctx, cachedOrNewConn)
1748 if err != driver.ErrBadConn {
1749 break
1750 }
1751 }
1752 if err == driver.ErrBadConn {
1753 dc, err = db.conn(ctx, alwaysNewConn)
1754 }
1755 if err != nil {
1756 return nil, err
1757 }
1758
1759 conn := &Conn{
1760 db: db,
1761 dc: dc,
1762 }
1763 return conn, nil
1764 }
1765
// releaseConn is the type of the callback returned by grabConn; callers
// invoke it with the operation's error once they are finished with the
// connection.
type releaseConn func(error)
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
// Conn wraps a single driverConn taken from a DB.
type Conn struct {
	// db is the owning database; close sets it to nil.
	db *DB

	// closemu is read-locked by grabConn for the duration of an operation
	// and write-locked by close, so the connection is not released while
	// an operation still holds it.
	closemu sync.RWMutex

	// dc is the underlying connection; close releases it and sets it to
	// nil.
	dc *driverConn

	// done is 0 while the Conn is usable and is set to 1 by close.
	// It is only accessed through sync/atomic.
	done int32
}
1794
1795
1796
1797 func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
1798 if atomic.LoadInt32(&c.done) != 0 {
1799 return nil, nil, ErrConnDone
1800 }
1801 c.closemu.RLock()
1802 return c.dc, c.closemuRUnlockCondReleaseConn, nil
1803 }
1804
1805
1806 func (c *Conn) PingContext(ctx context.Context) error {
1807 dc, release, err := c.grabConn(ctx)
1808 if err != nil {
1809 return err
1810 }
1811 return c.db.pingDC(ctx, dc, release)
1812 }
1813
1814
1815
1816 func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
1817 dc, release, err := c.grabConn(ctx)
1818 if err != nil {
1819 return nil, err
1820 }
1821 return c.db.execDC(ctx, dc, release, query, args)
1822 }
1823
1824
1825
1826 func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
1827 dc, release, err := c.grabConn(ctx)
1828 if err != nil {
1829 return nil, err
1830 }
1831 return c.db.queryDC(ctx, nil, dc, release, query, args)
1832 }
1833
1834
1835
1836
1837
1838
1839
1840 func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
1841 rows, err := c.QueryContext(ctx, query, args...)
1842 return &Row{rows: rows, err: err}
1843 }
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853 func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
1854 dc, release, err := c.grabConn(ctx)
1855 if err != nil {
1856 return nil, err
1857 }
1858 return c.db.prepareDC(ctx, dc, release, c, query)
1859 }
1860
1861
1862
1863
1864
1865
1866 func (c *Conn) Raw(f func(driverConn interface{}) error) (err error) {
1867 var dc *driverConn
1868 var release releaseConn
1869
1870
1871 dc, release, err = c.grabConn(nil)
1872 if err != nil {
1873 return
1874 }
1875 fPanic := true
1876 dc.Mutex.Lock()
1877 defer func() {
1878 dc.Mutex.Unlock()
1879
1880
1881
1882
1883 if fPanic {
1884 err = driver.ErrBadConn
1885 }
1886 release(err)
1887 }()
1888 err = f(dc.ci)
1889 fPanic = false
1890
1891 return
1892 }
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904 func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
1905 dc, release, err := c.grabConn(ctx)
1906 if err != nil {
1907 return nil, err
1908 }
1909 return c.db.beginDC(ctx, dc, release, opts)
1910 }
1911
1912
1913
1914 func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
1915 c.closemu.RUnlock()
1916 if err == driver.ErrBadConn {
1917 c.close(err)
1918 }
1919 }
1920
1921 func (c *Conn) txCtx() context.Context {
1922 return nil
1923 }
1924
1925 func (c *Conn) close(err error) error {
1926 if !atomic.CompareAndSwapInt32(&c.done, 0, 1) {
1927 return ErrConnDone
1928 }
1929
1930
1931
1932 c.closemu.Lock()
1933 defer c.closemu.Unlock()
1934
1935 c.dc.releaseConn(err)
1936 c.dc = nil
1937 c.db = nil
1938 return err
1939 }
1940
1941
1942
1943
1944
1945
1946 func (c *Conn) Close() error {
1947 return c.close(nil)
1948 }
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
// Tx is a transaction running on a single driver connection.
type Tx struct {
	// db is the database the transaction was started on.
	db *DB

	// closemu is read-locked by grabConn for the duration of each
	// operation and write-locked by close, so the connection is not
	// released while an operation is still using it.
	closemu sync.RWMutex

	// dc is the connection the transaction runs on and txi is the
	// driver's transaction handle. close sets both to nil.
	dc  *driverConn
	txi driver.Tx

	// releaseConn is called by close, with the final error, to hand the
	// connection back.
	releaseConn func(error)

	// done is switched from 0 to 1 by Commit or rollback. It is only
	// accessed through sync/atomic.
	done int32

	// stmts records the statements created through this transaction so
	// that closePrepared can close them. The embedded mutex guards v.
	stmts struct {
		sync.Mutex
		v []*Stmt
	}

	// cancel cancels ctx; close calls it.
	cancel func()

	// ctx is the cancelable context created in beginDC. awaitDone blocks
	// until it is done.
	ctx context.Context
}
1996
1997
1998
1999 func (tx *Tx) awaitDone() {
2000
2001
2002 <-tx.ctx.Done()
2003
2004
2005
2006
2007
2008 tx.rollback(true)
2009 }
2010
2011 func (tx *Tx) isDone() bool {
2012 return atomic.LoadInt32(&tx.done) != 0
2013 }
2014
2015
2016
// ErrTxDone is returned by any operation performed on a transaction whose
// done flag is already set.
var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
2018
2019
2020
2021 func (tx *Tx) close(err error) {
2022 tx.cancel()
2023
2024 tx.closemu.Lock()
2025 defer tx.closemu.Unlock()
2026
2027 tx.releaseConn(err)
2028 tx.dc = nil
2029 tx.txi = nil
2030 }
2031
2032
2033
// hookTxGrabConn, when non-nil, is called by Tx.grabConn after the done
// check and while the read lock on closemu is held.
// NOTE(review): presumably set only by tests — confirm.
var hookTxGrabConn func()
2035
2036 func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
2037 select {
2038 default:
2039 case <-ctx.Done():
2040 return nil, nil, ctx.Err()
2041 }
2042
2043
2044
2045 tx.closemu.RLock()
2046 if tx.isDone() {
2047 tx.closemu.RUnlock()
2048 return nil, nil, ErrTxDone
2049 }
2050 if hookTxGrabConn != nil {
2051 hookTxGrabConn()
2052 }
2053 return tx.dc, tx.closemuRUnlockRelease, nil
2054 }
2055
2056 func (tx *Tx) txCtx() context.Context {
2057 return tx.ctx
2058 }
2059
2060
2061
2062
2063
2064 func (tx *Tx) closemuRUnlockRelease(error) {
2065 tx.closemu.RUnlock()
2066 }
2067
2068
2069 func (tx *Tx) closePrepared() {
2070 tx.stmts.Lock()
2071 defer tx.stmts.Unlock()
2072 for _, stmt := range tx.stmts.v {
2073 stmt.Close()
2074 }
2075 }
2076
2077
// Commit commits the transaction.
//
// It returns ErrTxDone if the transaction has already finished, or the
// context's error if the transaction context is done but the transaction
// has not been marked done yet. Otherwise it returns the driver's Commit
// error.
func (tx *Tx) Commit() error {
	// Non-blocking check of the transaction context before claiming the
	// transaction.
	select {
	default:
	case <-tx.ctx.Done():
		if atomic.LoadInt32(&tx.done) == 1 {
			return ErrTxDone
		}
		return tx.ctx.Err()
	}
	// Only one of Commit/rollback wins the 0 -> 1 transition.
	if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
		return ErrTxDone
	}
	var err error
	// The driver call is made while holding the connection's lock.
	withLock(tx.dc, func() {
		err = tx.txi.Commit()
	})
	// Statements are closed unless the driver reported a bad connection.
	if err != driver.ErrBadConn {
		tx.closePrepared()
	}
	tx.close(err)
	return err
}
2103
2104
2105
// rollback aborts the transaction and returns ErrTxDone if it has already
// finished.
//
// When discardConn is true the driver's Rollback error is replaced by
// driver.ErrBadConn, which is then both passed to close (and from there to
// the connection's release function) and returned.
func (tx *Tx) rollback(discardConn bool) error {
	// Only one of Commit/rollback wins the 0 -> 1 transition.
	if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
		return ErrTxDone
	}
	var err error
	// The driver call is made while holding the connection's lock.
	withLock(tx.dc, func() {
		err = tx.txi.Rollback()
	})
	// Statements are closed unless the driver reported a bad connection;
	// this check uses the driver's error, before any discardConn override.
	if err != driver.ErrBadConn {
		tx.closePrepared()
	}
	if discardConn {
		err = driver.ErrBadConn
	}
	tx.close(err)
	return err
}
2123
2124
2125 func (tx *Tx) Rollback() error {
2126 return tx.rollback(false)
2127 }
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139 func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
2140 dc, release, err := tx.grabConn(ctx)
2141 if err != nil {
2142 return nil, err
2143 }
2144
2145 stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
2146 if err != nil {
2147 return nil, err
2148 }
2149 tx.stmts.Lock()
2150 tx.stmts.v = append(tx.stmts.v, stmt)
2151 tx.stmts.Unlock()
2152 return stmt, nil
2153 }
2154
2155
2156
2157
2158
2159
2160
2161 func (tx *Tx) Prepare(query string) (*Stmt, error) {
2162 return tx.PrepareContext(context.Background(), query)
2163 }
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
// StmtContext returns a statement bound to this transaction that runs the
// same query as stmt.
//
// Failures are not returned directly: the result is a Stmt whose stickyErr
// is set, so the error surfaces when the statement is used. The returned
// statement is recorded in tx.stmts so that closePrepared closes it when
// the transaction ends.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
	dc, release, err := tx.grabConn(ctx)
	if err != nil {
		return &Stmt{stickyErr: err}
	}
	// The connection is only needed while preparing.
	defer release(nil)

	if tx.db != stmt.db {
		return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
	}
	var si driver.Stmt
	var parentStmt *Stmt
	stmt.mu.Lock()
	if stmt.closed || stmt.cg != nil {
		// stmt is closed or is itself bound to a Tx/Conn, so its driver
		// statements are not reused: prepare the query text afresh on this
		// transaction's connection. parentStmt stays nil in this case.
		stmt.mu.Unlock()
		withLock(dc, func() {
			si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
		})
		if err != nil {
			return &Stmt{stickyErr: err}
		}
	} else {
		stmt.removeClosedStmtLocked()
		// Look for a driver statement that stmt already prepared on this
		// transaction's connection.
		for _, v := range stmt.css {
			if v.dc == dc {
				si = v.ds.si
				break
			}
		}

		stmt.mu.Unlock()

		if si == nil {
			// Not prepared on this connection yet: do it now, which also
			// records it in stmt.css.
			var ds *driverStmt
			withLock(dc, func() {
				ds, err = stmt.prepareOnConnLocked(ctx, dc)
			})
			if err != nil {
				return &Stmt{stickyErr: err}
			}
			si = ds.si
		}
		parentStmt = stmt
	}

	txs := &Stmt{
		db: tx.db,
		cg: tx,
		cgds: &driverStmt{
			Locker: dc,
			si:     si,
		},
		parentStmt: parentStmt,
		query:      stmt.query,
	}
	if parentStmt != nil {
		// Register txs as a dependency of the parent statement; Stmt.Close
		// removes it again.
		tx.db.addDep(parentStmt, txs)
	}
	tx.stmts.Lock()
	tx.stmts.v = append(tx.stmts.v, txs)
	tx.stmts.Unlock()
	return txs
}
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264 func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
2265 return tx.StmtContext(context.Background(), stmt)
2266 }
2267
2268
2269
2270 func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
2271 dc, release, err := tx.grabConn(ctx)
2272 if err != nil {
2273 return nil, err
2274 }
2275 return tx.db.execDC(ctx, dc, release, query, args)
2276 }
2277
2278
2279
2280 func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
2281 return tx.ExecContext(context.Background(), query, args...)
2282 }
2283
2284
2285 func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
2286 dc, release, err := tx.grabConn(ctx)
2287 if err != nil {
2288 return nil, err
2289 }
2290
2291 return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
2292 }
2293
2294
2295 func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
2296 return tx.QueryContext(context.Background(), query, args...)
2297 }
2298
2299
2300
2301
2302
2303
2304
2305 func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
2306 rows, err := tx.QueryContext(ctx, query, args...)
2307 return &Row{rows: rows, err: err}
2308 }
2309
2310
2311
2312
2313
2314
2315
2316 func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
2317 return tx.QueryRowContext(context.Background(), query, args...)
2318 }
2319
2320
// connStmt pairs a connection with the driver statement that was prepared
// on it.
type connStmt struct {
	dc *driverConn
	ds *driverStmt
}
2325
2326
2327
// stmtConnGrabber is what a Stmt needs from the object it is bound to.
// Both *Tx and *Conn implement it.
type stmtConnGrabber interface {
	// grabConn returns the connection to use together with a release
	// function that must be called once the caller is done with it.
	grabConn(context.Context) (*driverConn, releaseConn, error)

	// txCtx returns the transaction context, or nil when there is none
	// (as is the case for *Conn).
	txCtx() context.Context
}
2338
// Compile-time checks that *Tx and *Conn implement stmtConnGrabber.
var (
	_ stmtConnGrabber = &Tx{}
	_ stmtConnGrabber = &Conn{}
)
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
// Stmt is a prepared statement.
type Stmt struct {
	db        *DB    // database the statement belongs to
	query     string // query text the statement was prepared from
	stickyErr error  // if non-nil, returned by connStmt and Close

	// closemu is read-locked while the statement is executing
	// (ExecContext, QueryContext) and write-locked by Close.
	closemu sync.RWMutex

	// cg is non-nil when the statement is bound to a Tx or Conn. In that
	// case the connection is obtained from cg and cgds is the driver
	// statement to use.
	cg   stmtConnGrabber
	cgds *driverStmt

	// parentStmt is set by Tx.StmtContext when this statement was derived
	// from an existing, open, unbound statement. Close then removes the
	// dependency registered on the parent.
	parentStmt *Stmt

	mu     sync.Mutex // guards the fields below
	closed bool

	// css lists the connections this statement has been prepared on,
	// with the matching driver statement. Used only when cg is nil.
	css []connStmt

	// lastNumClosed is the value of db.numClosed observed the last time
	// removeClosedStmtLocked pruned css.
	lastNumClosed uint64
}
2390
2391
2392
2393 func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
2394 s.closemu.RLock()
2395 defer s.closemu.RUnlock()
2396
2397 var res Result
2398 strategy := cachedOrNewConn
2399 for i := 0; i < maxBadConnRetries+1; i++ {
2400 if i == maxBadConnRetries {
2401 strategy = alwaysNewConn
2402 }
2403 dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
2404 if err != nil {
2405 if err == driver.ErrBadConn {
2406 continue
2407 }
2408 return nil, err
2409 }
2410
2411 res, err = resultFromStatement(ctx, dc.ci, ds, args...)
2412 releaseConn(err)
2413 if err != driver.ErrBadConn {
2414 return res, err
2415 }
2416 }
2417 return nil, driver.ErrBadConn
2418 }
2419
2420
2421
2422 func (s *Stmt) Exec(args ...interface{}) (Result, error) {
2423 return s.ExecContext(context.Background(), args...)
2424 }
2425
2426 func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
2427 ds.Lock()
2428 defer ds.Unlock()
2429
2430 dargs, err := driverArgsConnLocked(ci, ds, args)
2431 if err != nil {
2432 return nil, err
2433 }
2434
2435 resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
2436 if err != nil {
2437 return nil, err
2438 }
2439 return driverResult{ds.Locker, resi}, nil
2440 }
2441
2442
2443
2444
2445
2446 func (s *Stmt) removeClosedStmtLocked() {
2447 t := len(s.css)/2 + 1
2448 if t > 10 {
2449 t = 10
2450 }
2451 dbClosed := atomic.LoadUint64(&s.db.numClosed)
2452 if dbClosed-s.lastNumClosed < uint64(t) {
2453 return
2454 }
2455
2456 s.db.mu.Lock()
2457 for i := 0; i < len(s.css); i++ {
2458 if s.css[i].dc.dbmuClosed {
2459 s.css[i] = s.css[len(s.css)-1]
2460 s.css = s.css[:len(s.css)-1]
2461 i--
2462 }
2463 }
2464 s.db.mu.Unlock()
2465 s.lastNumClosed = dbClosed
2466 }
2467
2468
2469
2470
// connStmt returns a connection, the function that releases it, and a
// driver statement prepared on that connection.
//
// It fails with the statement's sticky error if one is set, or with a
// "statement is closed" error if the statement has been closed.
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
	if err = s.stickyErr; err != nil {
		return
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		err = errors.New("sql: statement is closed")
		return
	}

	// A statement bound to a Tx or Conn always uses that object's
	// connection and its dedicated driver statement; strategy is not
	// consulted.
	if s.cg != nil {
		s.mu.Unlock()
		dc, releaseConn, err = s.cg.grabConn(ctx)
		if err != nil {
			return
		}
		return dc, releaseConn, s.cgds, nil
	}

	s.removeClosedStmtLocked()
	s.mu.Unlock()

	dc, err = s.db.conn(ctx, strategy)
	if err != nil {
		return nil, nil, nil, err
	}

	// Reuse the driver statement if this connection has already prepared
	// the query.
	s.mu.Lock()
	for _, v := range s.css {
		if v.dc == dc {
			s.mu.Unlock()
			return dc, dc.releaseConn, v.ds, nil
		}
	}
	s.mu.Unlock()

	// Not prepared on this connection yet: prepare it while holding the
	// connection's lock.
	withLock(dc, func() {
		ds, err = s.prepareOnConnLocked(ctx, dc)
	})
	if err != nil {
		dc.releaseConn(err)
		return nil, nil, nil, err
	}

	return dc, dc.releaseConn, ds, nil
}
2521
2522
2523
2524 func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
2525 si, err := dc.prepareLocked(ctx, s.cg, s.query)
2526 if err != nil {
2527 return nil, err
2528 }
2529 cs := connStmt{dc, si}
2530 s.mu.Lock()
2531 s.css = append(s.css, cs)
2532 s.mu.Unlock()
2533 return cs.ds, nil
2534 }
2535
2536
2537
// QueryContext executes the prepared query with args and returns the
// resulting Rows.
//
// The statement is tried up to maxBadConnRetries times with
// cachedOrNewConn and once more with alwaysNewConn. An attempt is repeated
// only when it fails with driver.ErrBadConn; if every attempt does,
// driver.ErrBadConn is returned.
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
	s.closemu.RLock()
	defer s.closemu.RUnlock()

	var rowsi driver.Rows
	strategy := cachedOrNewConn
	for i := 0; i < maxBadConnRetries+1; i++ {
		if i == maxBadConnRetries {
			// Last attempt: insist on a fresh connection.
			strategy = alwaysNewConn
		}
		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
		if err != nil {
			if err == driver.ErrBadConn {
				continue
			}
			return nil, err
		}

		rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
		if err == nil {
			// On success the connection is not released here; that is
			// deferred to rows.releaseConn below.
			rows := &Rows{
				dc:    dc,
				rowsi: rowsi,
			}

			// Register rows as a dependency of the statement for as long
			// as they hold the connection.
			s.db.addDep(s, rows)

			// Releasing the rows releases the connection and then drops
			// the dependency registered above.
			rows.releaseConn = func(err error) {
				releaseConn(err)
				s.db.removeDep(s, rows)
			}
			var txctx context.Context
			if s.cg != nil {
				txctx = s.cg.txCtx()
			}
			rows.initContextClose(ctx, txctx)
			return rows, nil
		}

		releaseConn(err)
		if err != driver.ErrBadConn {
			return nil, err
		}
	}
	return nil, driver.ErrBadConn
}
2590
2591
2592
2593 func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
2594 return s.QueryContext(context.Background(), args...)
2595 }
2596
2597 func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
2598 ds.Lock()
2599 defer ds.Unlock()
2600 dargs, err := driverArgsConnLocked(ci, ds, args)
2601 if err != nil {
2602 return nil, err
2603 }
2604 return ctxDriverStmtQuery(ctx, ds.si, dargs)
2605 }
2606
2607
2608
2609
2610
2611
2612
2613 func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
2614 rows, err := s.QueryContext(ctx, args...)
2615 if err != nil {
2616 return &Row{err: err}
2617 }
2618 return &Row{rows: rows}
2619 }
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632 func (s *Stmt) QueryRow(args ...interface{}) *Row {
2633 return s.QueryRowContext(context.Background(), args...)
2634 }
2635
2636
// Close closes the statement. It returns the sticky error if one is set,
// and nil if the statement was already closed.
func (s *Stmt) Close() error {
	// Wait for running ExecContext/QueryContext calls, which hold the
	// read lock.
	s.closemu.Lock()
	defer s.closemu.Unlock()

	if s.stickyErr != nil {
		return s.stickyErr
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return nil
	}
	s.closed = true
	// Detach the Tx/Conn-bound driver statement so it can be closed below
	// without holding s.mu.
	txds := s.cgds
	s.cgds = nil

	s.mu.Unlock()

	// Not bound to a Tx or Conn: drop the statement's dependency on itself.
	if s.cg == nil {
		return s.db.removeDep(s, s)
	}

	if s.parentStmt != nil {
		// The driver statement is not closed here; only the dependency
		// registered on the parent (see Tx.StmtContext) is removed.
		return s.db.removeDep(s.parentStmt, s)
	}
	return txds.Close()
}
2666
2667 func (s *Stmt) finalClose() error {
2668 s.mu.Lock()
2669 defer s.mu.Unlock()
2670 if s.css != nil {
2671 for _, v := range s.css {
2672 s.db.noteUnusedDriverStatement(v.dc, v.ds)
2673 v.dc.removeOpenStmt(v.ds)
2674 }
2675 s.css = nil
2676 }
2677 return nil
2678 }
2679
2680
2681
// Rows is the result of a query. Next advances through the rows and Scan
// reads the current one.
type Rows struct {
	dc          *driverConn // connection the query ran on; locked around driver calls
	releaseConn func(error) // called by close with the driver's Close error
	rowsi       driver.Rows // the driver's row iterator
	cancel      func()      // set by initContextClose; called by close if non-nil
	closeStmt   *driverStmt // if non-nil, closed when the Rows are closed

	// closemu is read-locked by the accessors (Next, NextResultSet, Err,
	// Columns, ColumnTypes, Scan) and write-locked by close. It guards
	// closed and lasterr.
	closemu sync.RWMutex
	closed  bool
	lasterr error // last error from the driver, or the error passed to close

	// lastcols holds the values of the current row, filled by Next and
	// read by Scan. NextResultSet resets it to nil.
	lastcols []driver.Value
}
2702
2703
2704
2705 func (rs *Rows) lasterrOrErrLocked(err error) error {
2706 if rs.lasterr != nil && rs.lasterr != io.EOF {
2707 return rs.lasterr
2708 }
2709 return err
2710 }
2711
2712 func (rs *Rows) initContextClose(ctx, txctx context.Context) {
2713 if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
2714 return
2715 }
2716 ctx, rs.cancel = context.WithCancel(ctx)
2717 go rs.awaitDone(ctx, txctx)
2718 }
2719
2720
2721
2722
2723
2724 func (rs *Rows) awaitDone(ctx, txctx context.Context) {
2725 var txctxDone <-chan struct{}
2726 if txctx != nil {
2727 txctxDone = txctx.Done()
2728 }
2729 select {
2730 case <-ctx.Done():
2731 case <-txctxDone:
2732 }
2733 rs.close(ctx.Err())
2734 }
2735
2736
2737
2738
2739
2740
2741
2742 func (rs *Rows) Next() bool {
2743 var doClose, ok bool
2744 withLock(rs.closemu.RLocker(), func() {
2745 doClose, ok = rs.nextLocked()
2746 })
2747 if doClose {
2748 rs.Close()
2749 }
2750 return ok
2751 }
2752
2753 func (rs *Rows) nextLocked() (doClose, ok bool) {
2754 if rs.closed {
2755 return false, false
2756 }
2757
2758
2759
2760 rs.dc.Lock()
2761 defer rs.dc.Unlock()
2762
2763 if rs.lastcols == nil {
2764 rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
2765 }
2766
2767 rs.lasterr = rs.rowsi.Next(rs.lastcols)
2768 if rs.lasterr != nil {
2769
2770 if rs.lasterr != io.EOF {
2771 return true, false
2772 }
2773 nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
2774 if !ok {
2775 return true, false
2776 }
2777
2778
2779
2780 if !nextResultSet.HasNextResultSet() {
2781 doClose = true
2782 }
2783 return doClose, false
2784 }
2785 return false, true
2786 }
2787
2788
2789
2790
2791
2792
2793
2794
2795
// NextResultSet advances to the next result set and reports whether there
// is one. It returns false if the Rows are closed, if the driver does not
// implement driver.RowsNextResultSet, or if the driver reports an error;
// in the latter two cases the Rows are closed as well.
func (rs *Rows) NextResultSet() bool {
	var doClose bool
	// Registered first so that it runs last, after the locks taken below
	// have been released.
	defer func() {
		if doClose {
			rs.Close()
		}
	}()
	rs.closemu.RLock()
	defer rs.closemu.RUnlock()

	if rs.closed {
		return false
	}

	// Forget the current row; Scan requires a fresh Next.
	rs.lastcols = nil
	nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
	if !ok {
		doClose = true
		return false
	}

	// Driver calls are made while holding the connection's lock.
	rs.dc.Lock()
	defer rs.dc.Unlock()

	rs.lasterr = nextResultSet.NextResultSet()
	if rs.lasterr != nil {
		doClose = true
		return false
	}
	return true
}
2829
2830
2831
2832 func (rs *Rows) Err() error {
2833 rs.closemu.RLock()
2834 defer rs.closemu.RUnlock()
2835 return rs.lasterrOrErrLocked(nil)
2836 }
2837
// errRowsClosed and errNoRows are the fallback errors returned by Columns,
// ColumnTypes and Scan when the Rows are closed or have no driver rows and
// no more specific error was recorded.
var errRowsClosed = errors.New("sql: Rows are closed")
var errNoRows = errors.New("sql: no Rows available")
2840
2841
2842
2843 func (rs *Rows) Columns() ([]string, error) {
2844 rs.closemu.RLock()
2845 defer rs.closemu.RUnlock()
2846 if rs.closed {
2847 return nil, rs.lasterrOrErrLocked(errRowsClosed)
2848 }
2849 if rs.rowsi == nil {
2850 return nil, rs.lasterrOrErrLocked(errNoRows)
2851 }
2852 rs.dc.Lock()
2853 defer rs.dc.Unlock()
2854
2855 return rs.rowsi.Columns(), nil
2856 }
2857
2858
2859
2860 func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
2861 rs.closemu.RLock()
2862 defer rs.closemu.RUnlock()
2863 if rs.closed {
2864 return nil, rs.lasterrOrErrLocked(errRowsClosed)
2865 }
2866 if rs.rowsi == nil {
2867 return nil, rs.lasterrOrErrLocked(errNoRows)
2868 }
2869 rs.dc.Lock()
2870 defer rs.dc.Unlock()
2871
2872 return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
2873 }
2874
2875
// ColumnType holds the name and type information of one result column, as
// collected by rowsColumnInfoSetupConnLocked.
type ColumnType struct {
	name string

	// The has* flags record whether the driver supplied the corresponding
	// value(s) below.
	hasNullable       bool
	hasLength         bool
	hasPrecisionScale bool

	nullable     bool
	length       int64
	databaseType string
	precision    int64
	scale        int64
	scanType     reflect.Type
}
2890
2891
2892 func (ci *ColumnType) Name() string {
2893 return ci.name
2894 }
2895
2896
2897
2898
2899
2900
2901 func (ci *ColumnType) Length() (length int64, ok bool) {
2902 return ci.length, ci.hasLength
2903 }
2904
2905
2906
2907 func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
2908 return ci.precision, ci.scale, ci.hasPrecisionScale
2909 }
2910
2911
2912
2913
2914 func (ci *ColumnType) ScanType() reflect.Type {
2915 return ci.scanType
2916 }
2917
2918
2919
2920 func (ci *ColumnType) Nullable() (nullable, ok bool) {
2921 return ci.nullable, ci.hasNullable
2922 }
2923
2924
2925
2926
2927
2928
2929 func (ci *ColumnType) DatabaseTypeName() string {
2930 return ci.databaseType
2931 }
2932
2933 func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
2934 names := rowsi.Columns()
2935
2936 list := make([]*ColumnType, len(names))
2937 for i := range list {
2938 ci := &ColumnType{
2939 name: names[i],
2940 }
2941 list[i] = ci
2942
2943 if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
2944 ci.scanType = prop.ColumnTypeScanType(i)
2945 } else {
2946 ci.scanType = reflect.TypeOf(new(interface{})).Elem()
2947 }
2948 if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
2949 ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
2950 }
2951 if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
2952 ci.length, ci.hasLength = prop.ColumnTypeLength(i)
2953 }
2954 if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
2955 ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
2956 }
2957 if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
2958 ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
2959 }
2960 }
2961 return list
2962 }
2963
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021 func (rs *Rows) Scan(dest ...interface{}) error {
3022 rs.closemu.RLock()
3023
3024 if rs.lasterr != nil && rs.lasterr != io.EOF {
3025 rs.closemu.RUnlock()
3026 return rs.lasterr
3027 }
3028 if rs.closed {
3029 err := rs.lasterrOrErrLocked(errRowsClosed)
3030 rs.closemu.RUnlock()
3031 return err
3032 }
3033 rs.closemu.RUnlock()
3034
3035 if rs.lastcols == nil {
3036 return errors.New("sql: Scan called without calling Next")
3037 }
3038 if len(dest) != len(rs.lastcols) {
3039 return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
3040 }
3041 for i, sv := range rs.lastcols {
3042 err := convertAssignRows(dest[i], sv, rs)
3043 if err != nil {
3044 return fmt.Errorf(`sql: Scan error on column index %d, name %q: %v`, i, rs.rowsi.Columns()[i], err)
3045 }
3046 }
3047 return nil
3048 }
3049
3050
3051
// rowsCloseHook returns a function that Rows.close calls, if it is
// non-nil, with the Rows being closed and a pointer to the driver's Close
// error. The default returns nil, so nothing is called.
// NOTE(review): presumably replaced only by tests — confirm.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
3053
3054
3055
3056
3057
3058 func (rs *Rows) Close() error {
3059 return rs.close(nil)
3060 }
3061
// close closes the Rows and returns the error from the driver's Close. It
// is a no-op returning nil if the Rows are already closed.
//
// The err argument is recorded as rs.lasterr only if no error has been
// recorded yet; it is not the value returned.
func (rs *Rows) close(err error) error {
	rs.closemu.Lock()
	defer rs.closemu.Unlock()

	if rs.closed {
		return nil
	}
	rs.closed = true

	if rs.lasterr == nil {
		rs.lasterr = err
	}

	// From here on err holds the driver's Close error. The driver call is
	// made while holding the connection's lock.
	withLock(rs.dc, func() {
		err = rs.rowsi.Close()
	})
	if fn := rowsCloseHook(); fn != nil {
		fn(rs, &err)
	}
	// Cancel the context created by initContextClose, if any.
	if rs.cancel != nil {
		rs.cancel()
	}

	if rs.closeStmt != nil {
		rs.closeStmt.Close()
	}
	rs.releaseConn(err)
	return err
}
3091
3092
// Row is the result of a single-row query. Exactly one of err and rows is
// set by the QueryRow* constructors in this file.
type Row struct {
	// err is the deferred query error; Scan returns it if non-nil.
	err  error
	rows *Rows
}
3098
3099
3100
3101
3102
3103
3104 func (r *Row) Scan(dest ...interface{}) error {
3105 if r.err != nil {
3106 return r.err
3107 }
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122 defer r.rows.Close()
3123 for _, dp := range dest {
3124 if _, ok := dp.(*RawBytes); ok {
3125 return errors.New("sql: RawBytes isn't allowed on Row.Scan")
3126 }
3127 }
3128
3129 if !r.rows.Next() {
3130 if err := r.rows.Err(); err != nil {
3131 return err
3132 }
3133 return ErrNoRows
3134 }
3135 err := r.rows.Scan(dest...)
3136 if err != nil {
3137 return err
3138 }
3139
3140 return r.rows.Close()
3141 }
3142
3143
// Result summarizes an executed statement.
type Result interface {
	// LastInsertId returns the last-insert ID reported by the driver's
	// result.
	// NOTE(review): presumably not supported by every database — confirm
	// against the driver documentation.
	LastInsertId() (int64, error)

	// RowsAffected returns the affected-row count reported by the
	// driver's result.
	RowsAffected() (int64, error)
}
3157
// driverResult implements Result on top of a driver.Result. The embedded
// Locker is held around every call into resi.
type driverResult struct {
	sync.Locker // the connection's lock (see resultFromStatement)
	resi        driver.Result
}
3162
3163 func (dr driverResult) LastInsertId() (int64, error) {
3164 dr.Lock()
3165 defer dr.Unlock()
3166 return dr.resi.LastInsertId()
3167 }
3168
3169 func (dr driverResult) RowsAffected() (int64, error) {
3170 dr.Lock()
3171 defer dr.Unlock()
3172 return dr.resi.RowsAffected()
3173 }
3174
// stack returns the calling goroutine's stack trace as a string, truncated
// to at most 2 KiB.
func stack() string {
	const maxLen = 2 << 10
	buf := make([]byte, maxLen)
	n := runtime.Stack(buf, false)
	return string(buf[:n])
}
3179
3180
// withLock runs fn while holding lk. The lock is released through a defer,
// so it is dropped even if fn panics.
func withLock(lk sync.Locker, fn func()) {
	lk.Lock()
	defer func() {
		lk.Unlock()
	}()
	fn()
}
3186
View as plain text