...
Source file src/sort/genzfunc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package main
16
17 import (
18 "bytes"
19 "go/ast"
20 "go/format"
21 "go/parser"
22 "go/token"
23 "io/ioutil"
24 "log"
25 "regexp"
26 )
27
var (
	// fset holds position information for sort.go as parsed by main; main
	// passes the same set to format.Node when printing the rewritten AST.
	fset = token.NewFileSet()
)
29
30 func main() {
31 af, err := parser.ParseFile(fset, "sort.go", nil, 0)
32 if err != nil {
33 log.Fatal(err)
34 }
35 af.Doc = nil
36 af.Imports = nil
37 af.Comments = nil
38
39 var newDecl []ast.Decl
40 for _, d := range af.Decls {
41 fd, ok := d.(*ast.FuncDecl)
42 if !ok {
43 continue
44 }
45 if fd.Recv != nil || fd.Name.IsExported() {
46 continue
47 }
48 typ := fd.Type
49 if len(typ.Params.List) < 1 {
50 continue
51 }
52 arg0 := typ.Params.List[0]
53 arg0Name := arg0.Names[0].Name
54 arg0Type := arg0.Type.(*ast.Ident)
55 if arg0Name != "data" || arg0Type.Name != "Interface" {
56 continue
57 }
58 arg0Type.Name = "lessSwap"
59
60 newDecl = append(newDecl, fd)
61 }
62 af.Decls = newDecl
63 ast.Walk(visitFunc(rewriteCalls), af)
64
65 var out bytes.Buffer
66 if err := format.Node(&out, fset, af); err != nil {
67 log.Fatalf("format.Node: %v", err)
68 }
69
70
71 src := regexp.MustCompile(`\n{2,}`).ReplaceAll(out.Bytes(), []byte("\n"))
72
73
74
75
76 src = regexp.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func"))
77
78
79 src, err = format.Source(src)
80 if err != nil {
81 log.Fatalf("format.Source: %v on\n%s", err, src)
82 }
83
84 out.Reset()
85 out.WriteString(`// Code generated from sort.go using genzfunc.go; DO NOT EDIT.
86
87 // Copyright 2016 The Go Authors. All rights reserved.
88 // Use of this source code is governed by a BSD-style
89 // license that can be found in the LICENSE file.
90
91 `)
92 out.Write(src)
93
94 const target = "zfuncversion.go"
95 if err := ioutil.WriteFile(target, out.Bytes(), 0644); err != nil {
96 log.Fatal(err)
97 }
98 }
99
// visitFunc lets an ordinary function be used as an ast.Visitor.
type visitFunc func(ast.Node) ast.Visitor

// Visit forwards node to fn and returns fn's result unchanged. Under ast.Walk,
// a nil result stops descent into node's children; a non-nil result is the
// visitor used for them.
func (fn visitFunc) Visit(node ast.Node) ast.Visitor {
	return fn(node)
}
103
104 func rewriteCalls(n ast.Node) ast.Visitor {
105 ce, ok := n.(*ast.CallExpr)
106 if ok {
107 rewriteCall(ce)
108 }
109 return visitFunc(rewriteCalls)
110 }
111
// rewriteCall appends "_func" to the callee name of ce when the callee is a
// plain identifier and the call has at least one argument. This points calls
// between the copied functions at their _func variants.
//
// Calls left untouched:
//   - method or package-qualified calls such as data.Less(i, j), whose callee
//     is not an *ast.Ident;
//   - calls with no arguments;
//   - calls to Go's predeclared identifiers: conversions such as int(x) or
//     uint32(x), and builtins such as len(s). These never have a _func
//     variant, and renaming them would reference undefined names.
func rewriteCall(ce *ast.CallExpr) {
	ident, ok := ce.Fun.(*ast.Ident)
	if !ok {
		// e.g. a SelectorExpr such as data.Less(...).
		return
	}
	if len(ce.Args) < 1 {
		return
	}
	switch ident.Name {
	case "any", "bool", "byte", "complex64", "complex128", "error",
		"float32", "float64", "int", "int8", "int16", "int32", "int64",
		"rune", "string", "uint", "uint8", "uint16", "uint32", "uint64",
		"uintptr":
		// Conversion to a predeclared type.
		return
	case "append", "cap", "clear", "close", "complex", "copy", "delete",
		"imag", "len", "make", "max", "min", "new", "panic", "print",
		"println", "real", "recover":
		// Builtin function call.
		return
	}
	ident.Name += "_func"
}
127
View as plain text