Source file src/cmd/fix/main.go
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/scanner"
15 "go/token"
16 "io/ioutil"
17 "os"
18 "os/exec"
19 "path/filepath"
20 "runtime"
21 "sort"
22 "strings"
23 )
24
// fset is the token.FileSet passed to every parser.ParseFile and
// format.Node call in this command.
var fset = token.NewFileSet()

// exitCode is the status handed to os.Exit; report raises it to 2.
var exitCode = 0
29
// Command-line flags, registered in the order -r, -force, -diff.
var (
	// allowedRewrites, when non-empty, restricts processing to the listed fixes.
	allowedRewrites = flag.String("r", "", "restrict the rewrites to this comma-separated list")

	// forceRewrites lists fixes that run even when they are marked disabled.
	forceRewrites = flag.String("force", "", "force these fixes to run even if the code looks updated")

	// doDiff prints a diff of each fixed file instead of rewriting it.
	doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
)

// allowed and force hold the names parsed by main from -r and -force.
// Each stays nil when its flag is left empty.
var allowed, force map[string]bool
39
40
const (
	// debug, when true, makes processFile dump the source that failed to
	// re-parse after a fix, report the error, and exit immediately.
	debug = false
)
42
43 func usage() {
44 fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
45 flag.PrintDefaults()
46 fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
47 sort.Sort(byName(fixes))
48 for _, f := range fixes {
49 if f.disabled {
50 fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
51 } else {
52 fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
53 }
54 desc := strings.TrimSpace(f.desc)
55 desc = strings.ReplaceAll(desc, "\n", "\n\t")
56 fmt.Fprintf(os.Stderr, "\t%s\n", desc)
57 }
58 os.Exit(2)
59 }
60
61 func main() {
62 flag.Usage = usage
63 flag.Parse()
64
65 sort.Sort(byDate(fixes))
66
67 if *allowedRewrites != "" {
68 allowed = make(map[string]bool)
69 for _, f := range strings.Split(*allowedRewrites, ",") {
70 allowed[f] = true
71 }
72 }
73
74 if *forceRewrites != "" {
75 force = make(map[string]bool)
76 for _, f := range strings.Split(*forceRewrites, ",") {
77 force[f] = true
78 }
79 }
80
81 if flag.NArg() == 0 {
82 if err := processFile("standard input", true); err != nil {
83 report(err)
84 }
85 os.Exit(exitCode)
86 }
87
88 for i := 0; i < flag.NArg(); i++ {
89 path := flag.Arg(i)
90 switch dir, err := os.Stat(path); {
91 case err != nil:
92 report(err)
93 case dir.IsDir():
94 walkDir(path)
95 default:
96 if err := processFile(path, false); err != nil {
97 report(err)
98 }
99 }
100 }
101
102 os.Exit(exitCode)
103 }
104
// parserMode is the mode for every parser.ParseFile call in this command.
// Comments are parsed into the AST so that printing the AST keeps them.
const parserMode parser.Mode = parser.ParseComments
106
107 func gofmtFile(f *ast.File) ([]byte, error) {
108 var buf bytes.Buffer
109 if err := format.Node(&buf, fset, f); err != nil {
110 return nil, err
111 }
112 return buf.Bytes(), nil
113 }
114
115 func processFile(filename string, useStdin bool) error {
116 var f *os.File
117 var err error
118 var fixlog bytes.Buffer
119
120 if useStdin {
121 f = os.Stdin
122 } else {
123 f, err = os.Open(filename)
124 if err != nil {
125 return err
126 }
127 defer f.Close()
128 }
129
130 src, err := ioutil.ReadAll(f)
131 if err != nil {
132 return err
133 }
134
135 file, err := parser.ParseFile(fset, filename, src, parserMode)
136 if err != nil {
137 return err
138 }
139
140
141 newFile := file
142 fixed := false
143 for _, fix := range fixes {
144 if allowed != nil && !allowed[fix.name] {
145 continue
146 }
147 if fix.disabled && !force[fix.name] {
148 continue
149 }
150 if fix.f(newFile) {
151 fixed = true
152 fmt.Fprintf(&fixlog, " %s", fix.name)
153
154
155
156
157 newSrc, err := gofmtFile(newFile)
158 if err != nil {
159 return err
160 }
161 newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
162 if err != nil {
163 if debug {
164 fmt.Printf("%s", newSrc)
165 report(err)
166 os.Exit(exitCode)
167 }
168 return err
169 }
170 }
171 }
172 if !fixed {
173 return nil
174 }
175 fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
176
177
178
179
180
181
182
183 newSrc, err := gofmtFile(newFile)
184 if err != nil {
185 return err
186 }
187
188 if *doDiff {
189 data, err := diff(src, newSrc)
190 if err != nil {
191 return fmt.Errorf("computing diff: %s", err)
192 }
193 fmt.Printf("diff %s fixed/%s\n", filename, filename)
194 os.Stdout.Write(data)
195 return nil
196 }
197
198 if useStdin {
199 os.Stdout.Write(newSrc)
200 return nil
201 }
202
203 return ioutil.WriteFile(f.Name(), newSrc, 0)
204 }
205
206 var gofmtBuf bytes.Buffer
207
208 func gofmt(n interface{}) string {
209 gofmtBuf.Reset()
210 if err := format.Node(&gofmtBuf, fset, n); err != nil {
211 return "<" + err.Error() + ">"
212 }
213 return gofmtBuf.String()
214 }
215
216 func report(err error) {
217 scanner.PrintError(os.Stderr, err)
218 exitCode = 2
219 }
220
221 func walkDir(path string) {
222 filepath.Walk(path, visitFile)
223 }
224
225 func visitFile(path string, f os.FileInfo, err error) error {
226 if err == nil && isGoFile(f) {
227 err = processFile(path, false)
228 }
229 if err != nil {
230 report(err)
231 }
232 return nil
233 }
234
// isGoFile reports whether f is a non-directory entry whose name ends in
// ".go" and does not start with a dot, so hidden files are skipped.
func isGoFile(f os.FileInfo) bool {
	if f.IsDir() {
		return false
	}
	name := f.Name()
	return strings.HasSuffix(name, ".go") && !strings.HasPrefix(name, ".")
}
240
// writeTempFile creates a new temporary file in dir (the system default when
// dir is empty), named from prefix by ioutil.TempFile. It writes data to the
// file, closes it and returns the file's name.
//
// If writing or closing fails, the file is removed and the first error is
// returned with an empty name.
func writeTempFile(dir, prefix string, data []byte) (string, error) {
	tmp, err := ioutil.TempFile(dir, prefix)
	if err != nil {
		return "", err
	}
	name := tmp.Name()

	_, writeErr := tmp.Write(data)
	closeErr := tmp.Close()
	if writeErr == nil {
		writeErr = closeErr
	}
	if writeErr != nil {
		os.Remove(name)
		return "", writeErr
	}
	return name, nil
}
256
257 func diff(b1, b2 []byte) (data []byte, err error) {
258 f1, err := writeTempFile("", "go-fix", b1)
259 if err != nil {
260 return
261 }
262 defer os.Remove(f1)
263
264 f2, err := writeTempFile("", "go-fix", b2)
265 if err != nil {
266 return
267 }
268 defer os.Remove(f2)
269
270 cmd := "diff"
271 if runtime.GOOS == "plan9" {
272 cmd = "/bin/ape/diff"
273 }
274
275 data, err = exec.Command(cmd, "-u", f1, f2).CombinedOutput()
276 if len(data) > 0 {
277
278
279 err = nil
280 }
281 return
282 }
283
View as plain text