Source file src/pkg/cmd/vendor/golang.org/x/tools/go/analysis/passes/lostcancel/lostcancel.go
1
2
3
4
5
6
7 package lostcancel
8
9 import (
10 "fmt"
11 "go/ast"
12 "go/types"
13
14 "golang.org/x/tools/go/analysis"
15 "golang.org/x/tools/go/analysis/passes/ctrlflow"
16 "golang.org/x/tools/go/analysis/passes/inspect"
17 "golang.org/x/tools/go/ast/inspector"
18 "golang.org/x/tools/go/cfg"
19 )
20
// Doc is the documentation text for Analyzer.
const Doc = `check cancel func returned by context.WithCancel is called

The cancelation function returned by context.WithCancel, WithTimeout,
and WithDeadline must be called or the new context will remain live
until its parent context is cancelled.
(The background context is never cancelled.)`
27
28 var Analyzer = &analysis.Analyzer{
29 Name: "lostcancel",
30 Doc: Doc,
31 Run: run,
32 Requires: []*analysis.Analyzer{
33 inspect.Analyzer,
34 ctrlflow.Analyzer,
35 },
36 }
37
// debug, when true, prints each analyzed function's control-flow graph and
// the block path that led to a lost-cancel report.
const debug = false

// contextPackage is the import path whose WithCancel-style functions are
// checked. It is a variable rather than a constant, presumably so that
// tests can substitute another path (TODO confirm).
var contextPackage = "context"
41
42
43
44
45
46
47
48
49
50
51
52 func run(pass *analysis.Pass) (interface{}, error) {
53
54 if !hasImport(pass.Pkg, contextPackage) {
55 return nil, nil
56 }
57
58
59 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
60 nodeTypes := []ast.Node{
61 (*ast.FuncLit)(nil),
62 (*ast.FuncDecl)(nil),
63 }
64 inspect.Preorder(nodeTypes, func(n ast.Node) {
65 runFunc(pass, n)
66 })
67 return nil, nil
68 }
69
70 func runFunc(pass *analysis.Pass, node ast.Node) {
71
72 var funcScope *types.Scope
73 switch v := node.(type) {
74 case *ast.FuncLit:
75 funcScope = pass.TypesInfo.Scopes[v.Type]
76 case *ast.FuncDecl:
77 funcScope = pass.TypesInfo.Scopes[v.Type]
78 }
79
80
81 cancelvars := make(map[*types.Var]ast.Node)
82
83
84
85
86
87
88 stack := make([]ast.Node, 0, 32)
89 ast.Inspect(node, func(n ast.Node) bool {
90 switch n.(type) {
91 case *ast.FuncLit:
92 if len(stack) > 0 {
93 return false
94 }
95 case nil:
96 stack = stack[:len(stack)-1]
97 return true
98 }
99 stack = append(stack, n)
100
101
102
103
104
105
106
107 if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-2]) {
108 return true
109 }
110 var id *ast.Ident
111 stmt := stack[len(stack)-3]
112 switch stmt := stmt.(type) {
113 case *ast.ValueSpec:
114 if len(stmt.Names) > 1 {
115 id = stmt.Names[1]
116 }
117 case *ast.AssignStmt:
118 if len(stmt.Lhs) > 1 {
119 id, _ = stmt.Lhs[1].(*ast.Ident)
120 }
121 }
122 if id != nil {
123 if id.Name == "_" {
124 pass.Reportf(id.Pos(),
125 "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
126 n.(*ast.SelectorExpr).Sel.Name)
127 } else if v, ok := pass.TypesInfo.Uses[id].(*types.Var); ok {
128
129
130 if funcScope.Contains(v.Pos()) {
131 cancelvars[v] = stmt
132 }
133 } else if v, ok := pass.TypesInfo.Defs[id].(*types.Var); ok {
134 cancelvars[v] = stmt
135 }
136 }
137 return true
138 })
139
140 if len(cancelvars) == 0 {
141 return
142 }
143
144
145 cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)
146 var g *cfg.CFG
147 var sig *types.Signature
148 switch node := node.(type) {
149 case *ast.FuncDecl:
150 sig, _ = pass.TypesInfo.Defs[node.Name].Type().(*types.Signature)
151 if node.Name.Name == "main" && sig.Recv() == nil && pass.Pkg.Name() == "main" {
152
153
154 return
155 }
156 g = cfgs.FuncDecl(node)
157
158 case *ast.FuncLit:
159 sig, _ = pass.TypesInfo.Types[node.Type].Type.(*types.Signature)
160 g = cfgs.FuncLit(node)
161 }
162 if sig == nil {
163 return
164 }
165
166
167 if debug {
168 fmt.Println(g.Format(pass.Fset))
169 }
170
171
172
173
174 for v, stmt := range cancelvars {
175 if ret := lostCancelPath(pass, g, v, stmt, sig); ret != nil {
176 lineno := pass.Fset.Position(stmt.Pos()).Line
177 pass.Reportf(stmt.Pos(), "the %s function is not used on all paths (possible context leak)", v.Name())
178 pass.Reportf(ret.Pos(), "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno)
179 }
180 }
181 }
182
// isCall reports whether n is a function call expression.
func isCall(n ast.Node) bool {
	switch n.(type) {
	case *ast.CallExpr:
		return true
	default:
		return false
	}
}
184
// hasImport reports whether path is the import path of one of the packages
// listed by pkg.Imports.
func hasImport(pkg *types.Package, path string) bool {
	imported := pkg.Imports()
	for i := range imported {
		if imported[i].Path() == path {
			return true
		}
	}
	return false
}
193
194
195
196 func isContextWithCancel(info *types.Info, n ast.Node) bool {
197 sel, ok := n.(*ast.SelectorExpr)
198 if !ok {
199 return false
200 }
201 switch sel.Sel.Name {
202 case "WithCancel", "WithTimeout", "WithDeadline":
203 default:
204 return false
205 }
206 if x, ok := sel.X.(*ast.Ident); ok {
207 if pkgname, ok := info.Uses[x].(*types.PkgName); ok {
208 return pkgname.Imported().Path() == contextPackage
209 }
210
211
212 return x.Name == "context"
213 }
214 return false
215 }
216
217
218
219
220
221 func lostCancelPath(pass *analysis.Pass, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
222 vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)
223
224
225 uses := func(pass *analysis.Pass, v *types.Var, stmts []ast.Node) bool {
226 found := false
227 for _, stmt := range stmts {
228 ast.Inspect(stmt, func(n ast.Node) bool {
229 switch n := n.(type) {
230 case *ast.Ident:
231 if pass.TypesInfo.Uses[n] == v {
232 found = true
233 }
234 case *ast.ReturnStmt:
235
236
237 if n.Results == nil && vIsNamedResult {
238 found = true
239 }
240 }
241 return !found
242 })
243 }
244 return found
245 }
246
247
248 memo := make(map[*cfg.Block]bool)
249 blockUses := func(pass *analysis.Pass, v *types.Var, b *cfg.Block) bool {
250 res, ok := memo[b]
251 if !ok {
252 res = uses(pass, v, b.Nodes)
253 memo[b] = res
254 }
255 return res
256 }
257
258
259
260 var defblock *cfg.Block
261 var rest []ast.Node
262 outer:
263 for _, b := range g.Blocks {
264 for i, n := range b.Nodes {
265 if n == stmt {
266 defblock = b
267 rest = b.Nodes[i+1:]
268 break outer
269 }
270 }
271 }
272 if defblock == nil {
273 panic("internal error: can't find defining block for cancel var")
274 }
275
276
277 if uses(pass, v, rest) {
278 return nil
279 }
280
281
282 if ret := defblock.Return(); ret != nil {
283 return ret
284 }
285
286
287
288 seen := make(map[*cfg.Block]bool)
289 var search func(blocks []*cfg.Block) *ast.ReturnStmt
290 search = func(blocks []*cfg.Block) *ast.ReturnStmt {
291 for _, b := range blocks {
292 if seen[b] {
293 continue
294 }
295 seen[b] = true
296
297
298 if blockUses(pass, v, b) {
299 continue
300 }
301
302
303 if ret := b.Return(); ret != nil {
304 if debug {
305 fmt.Printf("found path to return in block %s\n", b)
306 }
307 return ret
308 }
309
310
311 if ret := search(b.Succs); ret != nil {
312 if debug {
313 fmt.Printf(" from block %s\n", b)
314 }
315 return ret
316 }
317 }
318 return nil
319 }
320 return search(defblock.Succs)
321 }
322
// tupleContains reports whether v is one of the variables of tuple,
// compared by identity. A nil tuple contains nothing.
func tupleContains(tuple *types.Tuple, v *types.Var) bool {
	n := tuple.Len()
	i := 0
	for i < n && tuple.At(i) != v {
		i++
	}
	return i < n
}
331
View as plain text