Source file src/cmd/fix/typecheck.go
1
2
3
4
5 package main
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/parser"
11 "go/token"
12 "io/ioutil"
13 "os"
14 "os/exec"
15 "path/filepath"
16 "reflect"
17 "runtime"
18 "strings"
19 )
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
// mkType marks t as denoting a type (rather than a value of that type) by
// prepending the "type " prefix that isType and getType recognize.
func mkType(t string) string {
	const typePrefix = "type "
	return typePrefix + t
}
61
62 func getType(t string) string {
63 if !isType(t) {
64 return ""
65 }
66 return t[len("type "):]
67 }
68
// isType reports whether t carries the "type " marker, i.e. whether it
// denotes a type itself rather than the type of a value.
func isType(t string) bool {
	const marker = "type "
	return strings.HasPrefix(t, marker)
}
72
73
74
75
76
77
78 type TypeConfig struct {
79 Type map[string]*Type
80 Var map[string]string
81 Func map[string]string
82
83
84
85
86 External map[string]string
87 }
88
89
90
91 func (cfg *TypeConfig) typeof(name string) string {
92 if cfg.Var != nil {
93 if t := cfg.Var[name]; t != "" {
94 return t
95 }
96 }
97 if cfg.Func != nil {
98 if t := cfg.Func[name]; t != "" {
99 return "func()" + t
100 }
101 }
102 return ""
103 }
104
105
106
107
// Type describes what the checker knows about one named type.
type Type struct {
	Field  map[string]string // field name -> field type (gofmt form)
	Method map[string]string // method name -> type string returned by dot
	Embed  []string          // embedded type names, searched by dot after Field and Method
	Def    string            // definition used by expand (set for array, pointer and map types)
}
114
115
116
117 func (typ *Type) dot(cfg *TypeConfig, name string) string {
118 if typ.Field != nil {
119 if t := typ.Field[name]; t != "" {
120 return t
121 }
122 }
123 if typ.Method != nil {
124 if t := typ.Method[name]; t != "" {
125 return t
126 }
127 }
128
129 for _, e := range typ.Embed {
130 etyp := cfg.Type[e]
131 if etyp != nil {
132 if t := etyp.dot(cfg, name); t != "" {
133 return t
134 }
135 }
136 }
137
138 return ""
139 }
140
141
142
143
144
145
146 func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) {
147 typeof = make(map[interface{}]string)
148 assign = make(map[string][]interface{})
149 cfg1 := &TypeConfig{}
150 *cfg1 = *cfg
151 copied := false
152
153
154 cfg.External = map[string]string{}
155 cfg1.External = cfg.External
156 if imports(f, "C") {
157
158
159
160 err := func() error {
161 txt, err := gofmtFile(f)
162 if err != nil {
163 return err
164 }
165 dir, err := ioutil.TempDir(os.TempDir(), "fix_cgo_typecheck")
166 if err != nil {
167 return err
168 }
169 defer os.RemoveAll(dir)
170 err = ioutil.WriteFile(filepath.Join(dir, "in.go"), txt, 0600)
171 if err != nil {
172 return err
173 }
174 cmd := exec.Command(filepath.Join(runtime.GOROOT(), "bin", "go"), "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go")
175 err = cmd.Run()
176 if err != nil {
177 return err
178 }
179 out, err := ioutil.ReadFile(filepath.Join(dir, "_cgo_gotypes.go"))
180 if err != nil {
181 return err
182 }
183 cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0)
184 if err != nil {
185 return err
186 }
187 for _, decl := range cgo.Decls {
188 fn, ok := decl.(*ast.FuncDecl)
189 if !ok {
190 continue
191 }
192 if strings.HasPrefix(fn.Name.Name, "_Cfunc_") {
193 var params, results []string
194 for _, p := range fn.Type.Params.List {
195 t := gofmt(p.Type)
196 t = strings.ReplaceAll(t, "_Ctype_", "C.")
197 params = append(params, t)
198 }
199 for _, r := range fn.Type.Results.List {
200 t := gofmt(r.Type)
201 t = strings.ReplaceAll(t, "_Ctype_", "C.")
202 results = append(results, t)
203 }
204 cfg.External["C."+fn.Name.Name[7:]] = joinFunc(params, results)
205 }
206 }
207 return nil
208 }()
209 if err != nil {
210 fmt.Printf("warning: no cgo types: %s\n", err)
211 }
212 }
213
214
215 for _, decl := range f.Decls {
216 fn, ok := decl.(*ast.FuncDecl)
217 if !ok {
218 continue
219 }
220 typecheck1(cfg, fn.Type, typeof, assign)
221 t := typeof[fn.Type]
222 if fn.Recv != nil {
223
224 rcvr := typeof[fn.Recv]
225 if !isType(rcvr) {
226 if len(fn.Recv.List) != 1 {
227 continue
228 }
229 rcvr = mkType(gofmt(fn.Recv.List[0].Type))
230 typeof[fn.Recv.List[0].Type] = rcvr
231 }
232 rcvr = getType(rcvr)
233 if rcvr != "" && rcvr[0] == '*' {
234 rcvr = rcvr[1:]
235 }
236 typeof[rcvr+"."+fn.Name.Name] = t
237 } else {
238 if isType(t) {
239 t = getType(t)
240 } else {
241 t = gofmt(fn.Type)
242 }
243 typeof[fn.Name] = t
244
245
246 typeof[fn.Name.Obj] = t
247 }
248 }
249
250
251 for _, decl := range f.Decls {
252 d, ok := decl.(*ast.GenDecl)
253 if ok {
254 for _, s := range d.Specs {
255 switch s := s.(type) {
256 case *ast.TypeSpec:
257 if cfg1.Type[s.Name.Name] != nil {
258 break
259 }
260 if !copied {
261 copied = true
262
263 cfg1.Type = make(map[string]*Type)
264 for k, v := range cfg.Type {
265 cfg1.Type[k] = v
266 }
267 }
268 t := &Type{Field: map[string]string{}}
269 cfg1.Type[s.Name.Name] = t
270 switch st := s.Type.(type) {
271 case *ast.StructType:
272 for _, f := range st.Fields.List {
273 for _, n := range f.Names {
274 t.Field[n.Name] = gofmt(f.Type)
275 }
276 }
277 case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
278 t.Def = gofmt(st)
279 }
280 }
281 }
282 }
283 }
284
285 typecheck1(cfg1, f, typeof, assign)
286 return typeof, assign
287 }
288
// makeExprList converts a list of identifiers into a list of expressions,
// preserving order. A nil or empty input yields a nil slice.
func makeExprList(a []*ast.Ident) []ast.Expr {
	var exprs []ast.Expr
	for i := range a {
		exprs = append(exprs, a[i])
	}
	return exprs
}
296
297
298
299
300 func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) {
301
302
303 set := func(n ast.Expr, typ string, isDecl bool) {
304 if typeof[n] != "" || typ == "" {
305 if typeof[n] != typ {
306 assign[typ] = append(assign[typ], n)
307 }
308 return
309 }
310 typeof[n] = typ
311
312
313
314
315
316
317
318 if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
319 typeof[id.Obj] = typ
320 }
321 }
322
323
324
325
326 typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
327 if len(lhs) > 1 && len(rhs) == 1 {
328 if _, ok := rhs[0].(*ast.CallExpr); ok {
329 t := split(typeof[rhs[0]])
330
331 for i := 0; i < len(lhs) && i < len(t); i++ {
332 set(lhs[i], t[i], isDecl)
333 }
334 return
335 }
336 }
337 if len(lhs) == 1 && len(rhs) == 2 {
338
339 rhs = rhs[:1]
340 } else if len(lhs) == 2 && len(rhs) == 1 {
341
342 lhs = lhs[:1]
343 }
344
345
346 for i := 0; i < len(lhs) && i < len(rhs); i++ {
347 x, y := lhs[i], rhs[i]
348 if typeof[y] != "" {
349 set(x, typeof[y], isDecl)
350 } else {
351 set(y, typeof[x], false)
352 }
353 }
354 }
355
356 expand := func(s string) string {
357 typ := cfg.Type[s]
358 if typ != nil && typ.Def != "" {
359 return typ.Def
360 }
361 return s
362 }
363
364
365
366
367
368
369
370 var curfn []*ast.FuncType
371
372 before := func(n interface{}) {
373
374 switch n := n.(type) {
375 case *ast.FuncDecl:
376 curfn = append(curfn, n.Type)
377 case *ast.FuncLit:
378 curfn = append(curfn, n.Type)
379 }
380 }
381
382
383 after := func(n interface{}) {
384 if n == nil {
385 return
386 }
387 if false && reflect.TypeOf(n).Kind() == reflect.Ptr {
388 defer func() {
389 if t := typeof[n]; t != "" {
390 pos := fset.Position(n.(ast.Node).Pos())
391 fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
392 }
393 }()
394 }
395
396 switch n := n.(type) {
397 case *ast.FuncDecl, *ast.FuncLit:
398
399 curfn = curfn[:len(curfn)-1]
400
401 case *ast.FuncType:
402 typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
403
404 case *ast.FieldList:
405
406 t := ""
407 for _, field := range n.List {
408 if t != "" {
409 t += ", "
410 }
411 t += typeof[field]
412 }
413 typeof[n] = t
414
415 case *ast.Field:
416
417 all := ""
418 t := typeof[n.Type]
419 if !isType(t) {
420
421
422 t = mkType(gofmt(n.Type))
423 typeof[n.Type] = t
424 }
425 t = getType(t)
426 if len(n.Names) == 0 {
427 all = t
428 } else {
429 for _, id := range n.Names {
430 if all != "" {
431 all += ", "
432 }
433 all += t
434 typeof[id.Obj] = t
435 typeof[id] = t
436 }
437 }
438 typeof[n] = all
439
440 case *ast.ValueSpec:
441
442 if n.Type != nil {
443 t := typeof[n.Type]
444 if !isType(t) {
445 t = mkType(gofmt(n.Type))
446 typeof[n.Type] = t
447 }
448 t = getType(t)
449 for _, id := range n.Names {
450 set(id, t, true)
451 }
452 }
453
454 typecheckAssign(makeExprList(n.Names), n.Values, true)
455
456 case *ast.AssignStmt:
457 typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
458
459 case *ast.Ident:
460
461 if t := typeof[n.Obj]; t != "" {
462 typeof[n] = t
463 }
464
465 case *ast.SelectorExpr:
466
467 name := n.Sel.Name
468 if t := typeof[n.X]; t != "" {
469 t = strings.TrimPrefix(t, "*")
470 if typ := cfg.Type[t]; typ != nil {
471 if t := typ.dot(cfg, name); t != "" {
472 typeof[n] = t
473 return
474 }
475 }
476 tt := typeof[t+"."+name]
477 if isType(tt) {
478 typeof[n] = getType(tt)
479 return
480 }
481 }
482
483 if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
484 str := x.Name + "." + name
485 if cfg.Type[str] != nil {
486 typeof[n] = mkType(str)
487 return
488 }
489 if t := cfg.typeof(x.Name + "." + name); t != "" {
490 typeof[n] = t
491 return
492 }
493 }
494
495 case *ast.CallExpr:
496
497 if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
498 typeof[n] = gofmt(n.Args[0])
499 return
500 }
501
502 if isTopName(n.Fun, "new") && len(n.Args) == 1 {
503 typeof[n] = "*" + gofmt(n.Args[0])
504 return
505 }
506
507 t := typeof[n.Fun]
508 if t == "" {
509 t = cfg.External[gofmt(n.Fun)]
510 }
511 in, out := splitFunc(t)
512 if in == nil && out == nil {
513 return
514 }
515 typeof[n] = join(out)
516 for i, arg := range n.Args {
517 if i >= len(in) {
518 break
519 }
520 if typeof[arg] == "" {
521 typeof[arg] = in[i]
522 }
523 }
524
525 case *ast.TypeAssertExpr:
526
527 if n.Type == nil {
528 typeof[n] = typeof[n.X]
529 return
530 }
531
532 if t := typeof[n.Type]; isType(t) {
533 typeof[n] = getType(t)
534 } else {
535 typeof[n] = gofmt(n.Type)
536 }
537
538 case *ast.SliceExpr:
539
540 typeof[n] = typeof[n.X]
541
542 case *ast.IndexExpr:
543
544 t := expand(typeof[n.X])
545 if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
546
547
548 if i := strings.Index(t, "]"); i >= 0 {
549 typeof[n] = t[i+1:]
550 }
551 }
552
553 case *ast.StarExpr:
554
555
556
557 t := expand(typeof[n.X])
558 if isType(t) {
559 typeof[n] = "type *" + getType(t)
560 } else if strings.HasPrefix(t, "*") {
561 typeof[n] = t[len("*"):]
562 }
563
564 case *ast.UnaryExpr:
565
566 t := typeof[n.X]
567 if t != "" && n.Op == token.AND {
568 typeof[n] = "*" + t
569 }
570
571 case *ast.CompositeLit:
572
573 typeof[n] = gofmt(n.Type)
574
575
576 t := expand(typeof[n])
577 if strings.HasPrefix(t, "[") {
578
579 if i := strings.Index(t, "]"); i >= 0 {
580 et := t[i+1:]
581 for _, e := range n.Elts {
582 if kv, ok := e.(*ast.KeyValueExpr); ok {
583 e = kv.Value
584 }
585 if typeof[e] == "" {
586 typeof[e] = et
587 }
588 }
589 }
590 }
591 if strings.HasPrefix(t, "map[") {
592
593 if i := strings.Index(t, "]"); i >= 0 {
594 kt, vt := t[4:i], t[i+1:]
595 for _, e := range n.Elts {
596 if kv, ok := e.(*ast.KeyValueExpr); ok {
597 if typeof[kv.Key] == "" {
598 typeof[kv.Key] = kt
599 }
600 if typeof[kv.Value] == "" {
601 typeof[kv.Value] = vt
602 }
603 }
604 }
605 }
606 }
607 if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 {
608 for _, e := range n.Elts {
609 if kv, ok := e.(*ast.KeyValueExpr); ok {
610 if ft := typ.Field[fmt.Sprintf("%s", kv.Key)]; ft != "" {
611 if typeof[kv.Value] == "" {
612 typeof[kv.Value] = ft
613 }
614 }
615 }
616 }
617 }
618
619 case *ast.ParenExpr:
620
621 typeof[n] = typeof[n.X]
622
623 case *ast.RangeStmt:
624 t := expand(typeof[n.X])
625 if t == "" {
626 return
627 }
628 var key, value string
629 if t == "string" {
630 key, value = "int", "rune"
631 } else if strings.HasPrefix(t, "[") {
632 key = "int"
633 if i := strings.Index(t, "]"); i >= 0 {
634 value = t[i+1:]
635 }
636 } else if strings.HasPrefix(t, "map[") {
637 if i := strings.Index(t, "]"); i >= 0 {
638 key, value = t[4:i], t[i+1:]
639 }
640 }
641 changed := false
642 if n.Key != nil && key != "" {
643 changed = true
644 set(n.Key, key, n.Tok == token.DEFINE)
645 }
646 if n.Value != nil && value != "" {
647 changed = true
648 set(n.Value, value, n.Tok == token.DEFINE)
649 }
650
651
652 if changed {
653 typecheck1(cfg, n.Body, typeof, assign)
654 }
655
656 case *ast.TypeSwitchStmt:
657
658
659
660
661 as, ok := n.Assign.(*ast.AssignStmt)
662 if !ok {
663 return
664 }
665 varx, ok := as.Lhs[0].(*ast.Ident)
666 if !ok {
667 return
668 }
669 t := typeof[varx]
670 for _, cas := range n.Body.List {
671 cas := cas.(*ast.CaseClause)
672 if len(cas.List) == 1 {
673
674
675 if tt := typeof[cas.List[0]]; isType(tt) {
676 tt = getType(tt)
677 typeof[varx] = tt
678 typeof[varx.Obj] = tt
679 typecheck1(cfg, cas.Body, typeof, assign)
680 }
681 }
682 }
683
684 typeof[varx] = t
685 typeof[varx.Obj] = t
686
687 case *ast.ReturnStmt:
688 if len(curfn) == 0 {
689
690 return
691 }
692 f := curfn[len(curfn)-1]
693 res := n.Results
694 if f.Results != nil {
695 t := split(typeof[f.Results])
696 for i := 0; i < len(res) && i < len(t); i++ {
697 set(res[i], t[i], false)
698 }
699 }
700
701 case *ast.BinaryExpr:
702
703 switch n.Op {
704 case token.EQL, token.NEQ:
705 if typeof[n.X] != "" && typeof[n.Y] == "" {
706 typeof[n.Y] = typeof[n.X]
707 }
708 if typeof[n.X] == "" && typeof[n.Y] != "" {
709 typeof[n.X] = typeof[n.Y]
710 }
711 }
712 }
713 }
714 walkBeforeAfter(f, before, after)
715 }
716
717
718
719
720
721
722
723 func splitFunc(s string) (in, out []string) {
724 if !strings.HasPrefix(s, "func(") {
725 return nil, nil
726 }
727
728 i := len("func(")
729 nparen := 0
730 for j := i; j < len(s); j++ {
731 switch s[j] {
732 case '(':
733 nparen++
734 case ')':
735 nparen--
736 if nparen < 0 {
737
738 out := strings.TrimSpace(s[j+1:])
739 if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
740 out = out[1 : len(out)-1]
741 }
742 return split(s[i:j]), split(out)
743 }
744 }
745 }
746 return nil, nil
747 }
748
749
750 func joinFunc(in, out []string) string {
751 outs := ""
752 if len(out) == 1 {
753 outs = " " + out[0]
754 } else if len(out) > 1 {
755 outs = " (" + join(out) + ")"
756 }
757 return "func(" + join(in) + ")" + outs
758 }
759
760
// split splits a comma-separated type list such as "int, error" into its
// elements. Commas nested inside (), [] or {} do not split, so elements such
// as "func(a, b int)", "struct{ a, b int }" or "M[K, V]" stay whole. Leading
// spaces of each element are dropped. It returns nil if the brackets are
// unbalanced, and an empty, non-nil slice for the empty string.
func split(s string) []string {
	out := []string{}
	start := 0 // index where the current element begins
	depth := 0 // bracket nesting depth
	for j := 0; j < len(s); j++ {
		switch s[j] {
		case ' ':
			if start == j {
				start++
			}
		case '(', '[', '{':
			depth++
		case ')', ']', '}':
			depth--
			if depth < 0 {
				return nil
			}
		case ',':
			if depth == 0 {
				if start < j {
					out = append(out, s[start:j])
				}
				start = j + 1
			}
		}
	}
	if depth != 0 {
		return nil
	}
	if start < len(s) {
		out = append(out, s[start:])
	}
	return out
}
797
798
// join is the inverse of split: it joins type strings with ", ".
func join(x []string) string {
	const sep = ", "
	return strings.Join(x, sep)
}
802
View as plain text