...

Source file src/cmd/fix/main.go

     1	// Copyright 2011 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package main
     6	
     7	import (
     8		"bytes"
     9		"flag"
    10		"fmt"
    11		"go/ast"
    12		"go/format"
    13		"go/parser"
    14		"go/scanner"
    15		"go/token"
    16		"io/ioutil"
    17		"os"
    18		"os/exec"
    19		"path/filepath"
    20		"runtime"
    21		"sort"
    22		"strings"
    23	)
    24	
    25	var (
    26		fset     = token.NewFileSet()
    27		exitCode = 0
    28	)
    29	
    30	var allowedRewrites = flag.String("r", "",
    31		"restrict the rewrites to this comma-separated list")
    32	
    33	var forceRewrites = flag.String("force", "",
    34		"force these fixes to run even if the code looks updated")
    35	
    36	var allowed, force map[string]bool
    37	
    38	var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
    39	
    40	// enable for debugging fix failures
    41	const debug = false // display incorrectly reformatted source and exit
    42	
    43	func usage() {
    44		fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    45		flag.PrintDefaults()
    46		fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    47		sort.Sort(byName(fixes))
    48		for _, f := range fixes {
    49			if f.disabled {
    50				fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    51			} else {
    52				fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    53			}
    54			desc := strings.TrimSpace(f.desc)
    55			desc = strings.ReplaceAll(desc, "\n", "\n\t")
    56			fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    57		}
    58		os.Exit(2)
    59	}
    60	
    61	func main() {
    62		flag.Usage = usage
    63		flag.Parse()
    64	
    65		sort.Sort(byDate(fixes))
    66	
    67		if *allowedRewrites != "" {
    68			allowed = make(map[string]bool)
    69			for _, f := range strings.Split(*allowedRewrites, ",") {
    70				allowed[f] = true
    71			}
    72		}
    73	
    74		if *forceRewrites != "" {
    75			force = make(map[string]bool)
    76			for _, f := range strings.Split(*forceRewrites, ",") {
    77				force[f] = true
    78			}
    79		}
    80	
    81		if flag.NArg() == 0 {
    82			if err := processFile("standard input", true); err != nil {
    83				report(err)
    84			}
    85			os.Exit(exitCode)
    86		}
    87	
    88		for i := 0; i < flag.NArg(); i++ {
    89			path := flag.Arg(i)
    90			switch dir, err := os.Stat(path); {
    91			case err != nil:
    92				report(err)
    93			case dir.IsDir():
    94				walkDir(path)
    95			default:
    96				if err := processFile(path, false); err != nil {
    97					report(err)
    98				}
    99			}
   100		}
   101	
   102		os.Exit(exitCode)
   103	}
   104	
   105	const parserMode = parser.ParseComments
   106	
   107	func gofmtFile(f *ast.File) ([]byte, error) {
   108		var buf bytes.Buffer
   109		if err := format.Node(&buf, fset, f); err != nil {
   110			return nil, err
   111		}
   112		return buf.Bytes(), nil
   113	}
   114	
   115	func processFile(filename string, useStdin bool) error {
   116		var f *os.File
   117		var err error
   118		var fixlog bytes.Buffer
   119	
   120		if useStdin {
   121			f = os.Stdin
   122		} else {
   123			f, err = os.Open(filename)
   124			if err != nil {
   125				return err
   126			}
   127			defer f.Close()
   128		}
   129	
   130		src, err := ioutil.ReadAll(f)
   131		if err != nil {
   132			return err
   133		}
   134	
   135		file, err := parser.ParseFile(fset, filename, src, parserMode)
   136		if err != nil {
   137			return err
   138		}
   139	
   140		// Apply all fixes to file.
   141		newFile := file
   142		fixed := false
   143		for _, fix := range fixes {
   144			if allowed != nil && !allowed[fix.name] {
   145				continue
   146			}
   147			if fix.disabled && !force[fix.name] {
   148				continue
   149			}
   150			if fix.f(newFile) {
   151				fixed = true
   152				fmt.Fprintf(&fixlog, " %s", fix.name)
   153	
   154				// AST changed.
   155				// Print and parse, to update any missing scoping
   156				// or position information for subsequent fixers.
   157				newSrc, err := gofmtFile(newFile)
   158				if err != nil {
   159					return err
   160				}
   161				newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   162				if err != nil {
   163					if debug {
   164						fmt.Printf("%s", newSrc)
   165						report(err)
   166						os.Exit(exitCode)
   167					}
   168					return err
   169				}
   170			}
   171		}
   172		if !fixed {
   173			return nil
   174		}
   175		fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   176	
   177		// Print AST.  We did that after each fix, so this appears
   178		// redundant, but it is necessary to generate gofmt-compatible
   179		// source code in a few cases. The official gofmt style is the
   180		// output of the printer run on a standard AST generated by the parser,
   181		// but the source we generated inside the loop above is the
   182		// output of the printer run on a mangled AST generated by a fixer.
   183		newSrc, err := gofmtFile(newFile)
   184		if err != nil {
   185			return err
   186		}
   187	
   188		if *doDiff {
   189			data, err := diff(src, newSrc)
   190			if err != nil {
   191				return fmt.Errorf("computing diff: %s", err)
   192			}
   193			fmt.Printf("diff %s fixed/%s\n", filename, filename)
   194			os.Stdout.Write(data)
   195			return nil
   196		}
   197	
   198		if useStdin {
   199			os.Stdout.Write(newSrc)
   200			return nil
   201		}
   202	
   203		return ioutil.WriteFile(f.Name(), newSrc, 0)
   204	}
   205	
   206	var gofmtBuf bytes.Buffer
   207	
   208	func gofmt(n interface{}) string {
   209		gofmtBuf.Reset()
   210		if err := format.Node(&gofmtBuf, fset, n); err != nil {
   211			return "<" + err.Error() + ">"
   212		}
   213		return gofmtBuf.String()
   214	}
   215	
   216	func report(err error) {
   217		scanner.PrintError(os.Stderr, err)
   218		exitCode = 2
   219	}
   220	
   221	func walkDir(path string) {
   222		filepath.Walk(path, visitFile)
   223	}
   224	
   225	func visitFile(path string, f os.FileInfo, err error) error {
   226		if err == nil && isGoFile(f) {
   227			err = processFile(path, false)
   228		}
   229		if err != nil {
   230			report(err)
   231		}
   232		return nil
   233	}
   234	
   235	func isGoFile(f os.FileInfo) bool {
   236		// ignore non-Go files
   237		name := f.Name()
   238		return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   239	}
   240	
   241	func writeTempFile(dir, prefix string, data []byte) (string, error) {
   242		file, err := ioutil.TempFile(dir, prefix)
   243		if err != nil {
   244			return "", err
   245		}
   246		_, err = file.Write(data)
   247		if err1 := file.Close(); err == nil {
   248			err = err1
   249		}
   250		if err != nil {
   251			os.Remove(file.Name())
   252			return "", err
   253		}
   254		return file.Name(), nil
   255	}
   256	
   257	func diff(b1, b2 []byte) (data []byte, err error) {
   258		f1, err := writeTempFile("", "go-fix", b1)
   259		if err != nil {
   260			return
   261		}
   262		defer os.Remove(f1)
   263	
   264		f2, err := writeTempFile("", "go-fix", b2)
   265		if err != nil {
   266			return
   267		}
   268		defer os.Remove(f2)
   269	
   270		cmd := "diff"
   271		if runtime.GOOS == "plan9" {
   272			cmd = "/bin/ape/diff"
   273		}
   274	
   275		data, err = exec.Command(cmd, "-u", f1, f2).CombinedOutput()
   276		if len(data) > 0 {
   277			// diff exits with a non-zero status when the files don't match.
   278			// Ignore that failure as long as we get output.
   279			err = nil
   280		}
   281		return
   282	}
   283	

View as plain text