...

Source file src/cmd/gofmt/rewrite.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package main
     6	
     7	import (
     8		"fmt"
     9		"go/ast"
    10		"go/parser"
    11		"go/token"
    12		"os"
    13		"reflect"
    14		"strings"
    15		"unicode"
    16		"unicode/utf8"
    17	)
    18	
    19	func initRewrite() {
    20		if *rewriteRule == "" {
    21			rewrite = nil // disable any previous rewrite
    22			return
    23		}
    24		f := strings.Split(*rewriteRule, "->")
    25		if len(f) != 2 {
    26			fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
    27			os.Exit(2)
    28		}
    29		pattern := parseExpr(f[0], "pattern")
    30		replace := parseExpr(f[1], "replacement")
    31		rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }
    32	}
    33	
    34	// parseExpr parses s as an expression.
    35	// It might make sense to expand this to allow statement patterns,
    36	// but there are problems with preserving formatting and also
    37	// with what a wildcard for a statement looks like.
    38	func parseExpr(s, what string) ast.Expr {
    39		x, err := parser.ParseExpr(s)
    40		if err != nil {
    41			fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
    42			os.Exit(2)
    43		}
    44		return x
    45	}
    46	
    47	// Keep this function for debugging.
    48	/*
    49	func dump(msg string, val reflect.Value) {
    50		fmt.Printf("%s:\n", msg)
    51		ast.Print(fileSet, val.Interface())
    52		fmt.Println()
    53	}
    54	*/
    55	
    56	// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
    57	func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
    58		cmap := ast.NewCommentMap(fileSet, p, p.Comments)
    59		m := make(map[string]reflect.Value)
    60		pat := reflect.ValueOf(pattern)
    61		repl := reflect.ValueOf(replace)
    62	
    63		var rewriteVal func(val reflect.Value) reflect.Value
    64		rewriteVal = func(val reflect.Value) reflect.Value {
    65			// don't bother if val is invalid to start with
    66			if !val.IsValid() {
    67				return reflect.Value{}
    68			}
    69			val = apply(rewriteVal, val)
    70			for k := range m {
    71				delete(m, k)
    72			}
    73			if match(m, pat, val) {
    74				val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
    75			}
    76			return val
    77		}
    78	
    79		r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
    80		r.Comments = cmap.Filter(r).Comments() // recreate comments list
    81		return r
    82	}
    83	
    84	// set is a wrapper for x.Set(y); it protects the caller from panics if x cannot be changed to y.
    85	func set(x, y reflect.Value) {
    86		// don't bother if x cannot be set or y is invalid
    87		if !x.CanSet() || !y.IsValid() {
    88			return
    89		}
    90		defer func() {
    91			if x := recover(); x != nil {
    92				if s, ok := x.(string); ok &&
    93					(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
    94					// x cannot be set to y - ignore this rewrite
    95					return
    96				}
    97				panic(x)
    98			}
    99		}()
   100		x.Set(y)
   101	}
   102	
   103	// Values/types for special cases.
   104	var (
   105		objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
   106		scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))
   107	
   108		identType     = reflect.TypeOf((*ast.Ident)(nil))
   109		objectPtrType = reflect.TypeOf((*ast.Object)(nil))
   110		positionType  = reflect.TypeOf(token.NoPos)
   111		callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
   112		scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
   113	)
   114	
   115	// apply replaces each AST field x in val with f(x), returning val.
   116	// To avoid extra conversions, f operates on the reflect.Value form.
   117	func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
   118		if !val.IsValid() {
   119			return reflect.Value{}
   120		}
   121	
   122		// *ast.Objects introduce cycles and are likely incorrect after
   123		// rewrite; don't follow them but replace with nil instead
   124		if val.Type() == objectPtrType {
   125			return objectPtrNil
   126		}
   127	
   128		// similarly for scopes: they are likely incorrect after a rewrite;
   129		// replace them with nil
   130		if val.Type() == scopePtrType {
   131			return scopePtrNil
   132		}
   133	
   134		switch v := reflect.Indirect(val); v.Kind() {
   135		case reflect.Slice:
   136			for i := 0; i < v.Len(); i++ {
   137				e := v.Index(i)
   138				set(e, f(e))
   139			}
   140		case reflect.Struct:
   141			for i := 0; i < v.NumField(); i++ {
   142				e := v.Field(i)
   143				set(e, f(e))
   144			}
   145		case reflect.Interface:
   146			e := v.Elem()
   147			set(v, f(e))
   148		}
   149		return val
   150	}
   151	
   152	func isWildcard(s string) bool {
   153		rune, size := utf8.DecodeRuneInString(s)
   154		return size == len(s) && unicode.IsLower(rune)
   155	}
   156	
   157	// match reports whether pattern matches val,
   158	// recording wildcard submatches in m.
   159	// If m == nil, match checks whether pattern == val.
   160	func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
   161		// Wildcard matches any expression. If it appears multiple
   162		// times in the pattern, it must match the same expression
   163		// each time.
   164		if m != nil && pattern.IsValid() && pattern.Type() == identType {
   165			name := pattern.Interface().(*ast.Ident).Name
   166			if isWildcard(name) && val.IsValid() {
   167				// wildcards only match valid (non-nil) expressions.
   168				if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
   169					if old, ok := m[name]; ok {
   170						return match(nil, old, val)
   171					}
   172					m[name] = val
   173					return true
   174				}
   175			}
   176		}
   177	
   178		// Otherwise, pattern and val must match recursively.
   179		if !pattern.IsValid() || !val.IsValid() {
   180			return !pattern.IsValid() && !val.IsValid()
   181		}
   182		if pattern.Type() != val.Type() {
   183			return false
   184		}
   185	
   186		// Special cases.
   187		switch pattern.Type() {
   188		case identType:
   189			// For identifiers, only the names need to match
   190			// (and none of the other *ast.Object information).
   191			// This is a common case, handle it all here instead
   192			// of recursing down any further via reflection.
   193			p := pattern.Interface().(*ast.Ident)
   194			v := val.Interface().(*ast.Ident)
   195			return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
   196		case objectPtrType, positionType:
   197			// object pointers and token positions always match
   198			return true
   199		case callExprType:
   200			// For calls, the Ellipsis fields (token.Position) must
   201			// match since that is how f(x) and f(x...) are different.
   202			// Check them here but fall through for the remaining fields.
   203			p := pattern.Interface().(*ast.CallExpr)
   204			v := val.Interface().(*ast.CallExpr)
   205			if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
   206				return false
   207			}
   208		}
   209	
   210		p := reflect.Indirect(pattern)
   211		v := reflect.Indirect(val)
   212		if !p.IsValid() || !v.IsValid() {
   213			return !p.IsValid() && !v.IsValid()
   214		}
   215	
   216		switch p.Kind() {
   217		case reflect.Slice:
   218			if p.Len() != v.Len() {
   219				return false
   220			}
   221			for i := 0; i < p.Len(); i++ {
   222				if !match(m, p.Index(i), v.Index(i)) {
   223					return false
   224				}
   225			}
   226			return true
   227	
   228		case reflect.Struct:
   229			for i := 0; i < p.NumField(); i++ {
   230				if !match(m, p.Field(i), v.Field(i)) {
   231					return false
   232				}
   233			}
   234			return true
   235	
   236		case reflect.Interface:
   237			return match(m, p.Elem(), v.Elem())
   238		}
   239	
   240		// Handle token integers, etc.
   241		return p.Interface() == v.Interface()
   242	}
   243	
   244	// subst returns a copy of pattern with values from m substituted in place
   245	// of wildcards and pos used as the position of tokens from the pattern.
   246	// if m == nil, subst returns a copy of pattern and doesn't change the line
   247	// number information.
   248	func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
   249		if !pattern.IsValid() {
   250			return reflect.Value{}
   251		}
   252	
   253		// Wildcard gets replaced with map value.
   254		if m != nil && pattern.Type() == identType {
   255			name := pattern.Interface().(*ast.Ident).Name
   256			if isWildcard(name) {
   257				if old, ok := m[name]; ok {
   258					return subst(nil, old, reflect.Value{})
   259				}
   260			}
   261		}
   262	
   263		if pos.IsValid() && pattern.Type() == positionType {
   264			// use new position only if old position was valid in the first place
   265			if old := pattern.Interface().(token.Pos); !old.IsValid() {
   266				return pattern
   267			}
   268			return pos
   269		}
   270	
   271		// Otherwise copy.
   272		switch p := pattern; p.Kind() {
   273		case reflect.Slice:
   274			v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
   275			for i := 0; i < p.Len(); i++ {
   276				v.Index(i).Set(subst(m, p.Index(i), pos))
   277			}
   278			return v
   279	
   280		case reflect.Struct:
   281			v := reflect.New(p.Type()).Elem()
   282			for i := 0; i < p.NumField(); i++ {
   283				v.Field(i).Set(subst(m, p.Field(i), pos))
   284			}
   285			return v
   286	
   287		case reflect.Ptr:
   288			v := reflect.New(p.Type()).Elem()
   289			if elem := p.Elem(); elem.IsValid() {
   290				v.Set(subst(m, elem, pos).Addr())
   291			}
   292			return v
   293	
   294		case reflect.Interface:
   295			v := reflect.New(p.Type()).Elem()
   296			if elem := p.Elem(); elem.IsValid() {
   297				v.Set(subst(m, elem, pos))
   298			}
   299			return v
   300		}
   301	
   302		return pattern
   303	}
   304	

View as plain text