blob: dc5307a15d027d88db92f0a17753845f525c6956 [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testinggoroutine
import (
_ "embed"
"fmt"
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/analysis/passes/internal/analysisutil"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/go/types/typeutil"
)
// doc holds the text of doc.go, embedded into the binary at build time.
// Analyzer's Doc field is extracted from it.
//
//go:embed doc.go
var doc string
// reportSubtest holds the value of the analyzer's -subtest flag. When it
// is set, run also reports forbidden calls made inside a t.Run callback on
// a variable declared outside that callback.
var reportSubtest bool

// init registers the experimental -subtest flag (off by default) on
// Analyzer's flag set.
func init() {
	const (
		flagName  = "subtest"
		flagUsage = "whether to check if t.Run subtest is terminated correctly; experimental"
	)
	Analyzer.Flags.BoolVar(&reportSubtest, flagName, false, flagUsage)
}
// Analyzer is the testinggoroutine analyzer. Its documentation is
// extracted from the embedded doc.go text, it depends on the result of
// inspect.Analyzer, and its logic is implemented by run.
var Analyzer = &analysis.Analyzer{
Name: "testinggoroutine",
Doc: analysisutil.MustExtractDoc(doc, "testinggoroutine"),
URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/testinggoroutine",
Requires: []*analysis.Analyzer{inspect.Analyzer},
Run: run,
}
// run implements the analysis for one package. It always returns
// (nil, nil); findings are delivered through pass.ReportRangef.
//
// The work happens in two phases:
//
//  1. Walk every function declaration that has a *testing.T or *testing.B
//     parameter and record, for each go statement and each t.Run call,
//     the region of code that will execute asynchronously.
//  2. Scan each recorded region for calls to a forbidden testing method
//     and report those whose receiver variable is not declared within
//     the scope associated with the asynchronous call.
func run(pass *analysis.Pass) (interface{}, error) {
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

	// Nothing to check in a package that does not import "testing".
	if !analysisutil.Imports(pass.Pkg, "testing") {
		return nil, nil
	}

	toDecl := localFunctionDecls(pass.TypesInfo, pass.Files)

	// callsByRegion maps a node whose statements run concurrently with
	// respect to some test function to the call sites that start it
	// asynchronously. A region may have several call sites, e.g. when it
	// is a shared test helper. order lists the regions in the order in
	// which they were first discovered.
	callsByRegion := make(map[ast.Node][]*asyncCall)
	var order []ast.Node
	record := func(c *asyncCall) {
		if c == nil {
			return
		}
		if _, seen := callsByRegion[c.region]; !seen {
			order = append(order, c.region)
		}
		callsByRegion[c.region] = append(callsByRegion[c.region], c)
	}

	// Phase 1: collect the extents of every go callee() and
	// t.Run(name, callee).
	nodeTypes := []ast.Node{
		(*ast.FuncDecl)(nil),
		(*ast.GoStmt)(nil),
		(*ast.CallExpr)(nil),
	}
	insp.Nodes(nodeTypes, func(node ast.Node, push bool) bool {
		if !push {
			return false
		}
		switch node := node.(type) {
		case *ast.FuncDecl:
			// Only descend into functions that take a testing parameter.
			return hasBenchmarkOrTestParams(node)
		case *ast.GoStmt:
			record(goAsyncCall(pass.TypesInfo, node, toDecl))
		case *ast.CallExpr:
			record(tRunAsyncCall(pass.TypesInfo, node))
		}
		return true
	})

	// Phase 2: look for t.Forbidden() calls inside each region that is
	// the callee of some go r() or t.Run("name", r). A region may also be
	// the go statement itself, which covers go t.Forbidden().
	for _, region := range order {
		ast.Inspect(region, func(n ast.Node) bool {
			if n == region {
				return true // always descend into the region itself.
			}
			if _, nested := callsByRegion[n]; nested {
				return false // visited separately as its own region.
			}

			call, isCall := n.(*ast.CallExpr)
			if !isCall {
				return true
			}
			x, sel, fn := forbiddenMethod(pass.TypesInfo, call)
			if x == nil {
				return true
			}

			for _, ac := range callsByRegion[region] {
				if withinScope(ac.scope, x) {
					continue
				}
				method := formatMethod(sel, fn) // e.g. "(*testing.T).Forbidden"

				// By default the report goes on the go fun() or
				// t.Run(name, fun) that starts the region.
				var (
					suffix string
					at     analysis.Range = ac.async
				)
				switch f := ac.fun.(type) {
				case *ast.FuncLit:
					at = call // report at the t.Forbidden() call itself.
				case *ast.Ident:
					suffix = fmt.Sprintf(" (%s calls %s)", f.Name, method)
				}

				switch ac.async.(type) {
				case *ast.GoStmt:
					pass.ReportRangef(at, "call to %s from a non-test goroutine%s", method, suffix)
				default:
					if reportSubtest {
						pass.ReportRangef(at, "call to %s on %s defined outside of the subtest%s", method, x.Name(), suffix)
					}
				}
			}
			return true
		})
	}

	return nil, nil
}
// hasBenchmarkOrTestParams reports whether at least one parameter of
// fnDecl is written as *testing.T or *testing.B.
func hasBenchmarkOrTestParams(fnDecl *ast.FuncDecl) bool {
	for _, field := range fnDecl.Type.Params.List {
		_, isTestingParam := typeIsTestingDotTOrB(field.Type)
		if isTestingParam {
			return true
		}
	}
	return false
}
// typeIsTestingDotTOrB reports whether expr is written as the pointer
// type *testing.T or *testing.B.
//
// The check is purely syntactic: expr must be a star expression whose
// operand is a selector with the identifier "testing" as its qualifier.
// When that shape matches, the selected name is returned together with a
// boolean telling whether it is "T" or "B"; the name is returned even when
// the boolean is false. For any other shape the result is ("", false).
func typeIsTestingDotTOrB(expr ast.Expr) (string, bool) {
	starExpr, ok := expr.(*ast.StarExpr)
	if !ok {
		return "", false
	}
	selExpr, ok := starExpr.X.(*ast.SelectorExpr)
	if !ok {
		return "", false
	}
	// Use a checked assertion: the qualifier need not be a plain
	// identifier (e.g. a nested selector in a malformed or synthesized
	// AST), and an unchecked assertion would panic on such input.
	varPkg, ok := selExpr.X.(*ast.Ident)
	if !ok || varPkg.Name != "testing" {
		return "", false
	}
	varTypeName := selExpr.Sel.Name
	ok = varTypeName == "B" || varTypeName == "T"
	return varTypeName, ok
}
// asyncCall describes a region of code that needs to be checked for
// t.Forbidden() calls because it is started asynchronously from an async
// node, i.e. go fun() or t.Run(name, fun).
type asyncCall struct {
region ast.Node // code to scan for t.Forbidden() calls: a function literal, a function declaration, or the async node itself.
async ast.Node // the node that starts the region: *ast.GoStmt, or *ast.CallExpr for t.Run.
scope ast.Node // t.Forbidden() is reported if t is not declared within scope; nil for go statements.
fun ast.Expr // fun in go fun() or t.Run(name, fun), with enclosing parentheses removed.
}
// withinScope reports whether the declaration position of x lies in the
// closed interval [scope.Pos(), scope.End()]. It reports false when scope
// is nil or when x has no position.
func withinScope(scope ast.Node, x *types.Var) bool {
	if scope == nil {
		return false
	}
	pos := x.Pos()
	if pos == token.NoPos {
		return false
	}
	return scope.Pos() <= pos && pos <= scope.End()
}
// goAsyncCall returns the extent of the call started by a go fun()
// statement. The region to check is chosen as follows:
//
//   - the function literal found by funcLitInScope, when fun resolves to
//     an identifier for which such a literal is found;
//   - the declaration returned by toDecl, when the call has a static
//     callee for which toDecl yields one;
//   - otherwise the go statement itself, which covers go t.Forbidden()
//     and go func(){ t.Forbidden() }().
//
// The scope of the result is always nil.
func goAsyncCall(info *types.Info, goStmt *ast.GoStmt, toDecl func(*types.Func) *ast.FuncDecl) *asyncCall {
	call := goStmt.Call
	fun := astutil.Unparen(call.Fun)

	// over builds the result for a given region; all other fields are
	// the same on every path.
	over := func(region ast.Node) *asyncCall {
		return &asyncCall{region: region, async: goStmt, scope: nil, fun: fun}
	}

	if id := funcIdent(fun); id != nil {
		if lit := funcLitInScope(id); lit != nil {
			return over(lit)
		}
	}

	// Static call of a function or method declared in this package?
	if callee := typeutil.StaticCallee(info, call); callee != nil {
		if decl := toDecl(callee); decl != nil {
			return over(decl)
		}
	}

	return over(goStmt)
}
// tRunAsyncCall returns the extent of the call started by a
// t.Run("name", fun) expression. It returns nil unless call has exactly
// two arguments and its callee is a method named Run from package testing.
//
// When fun is a function literal, or an identifier for which
// funcLitInScope finds one, that literal is both the region and the
// scope. Otherwise the region is the whole call and the scope is fun.
func tRunAsyncCall(info *types.Info, call *ast.CallExpr) *asyncCall {
	if len(call.Args) != 2 {
		return nil
	}
	callee, isFunc := typeutil.Callee(info, call).(*types.Func)
	if !isFunc || !isMethodNamed(callee, "testing", "Run") {
		return nil
	}

	fun := astutil.Unparen(call.Args[1])

	// over builds the result for a given region and scope; the async
	// node and fun are the same on every path.
	over := func(region, scope ast.Node) *asyncCall {
		return &asyncCall{region: region, async: call, scope: scope, fun: fun}
	}

	// Function literal passed directly?
	if lit, isLit := fun.(*ast.FuncLit); isLit {
		return over(lit, lit)
	}

	// Function literal held in a variable?
	if id := funcIdent(fun); id != nil {
		if lit := funcLitInScope(id); lit != nil {
			return over(lit, lit)
		}
	}

	// Fall back to checking within t.Run(name, fun) itself for calls to
	// t.Forbidden, e.g. t.Run(name, func(t *testing.T){ t.Forbidden() }).
	return over(call, fun)
}
// forbidden lists the names of the testing methods whose calls this
// analyzer looks for; forbiddenMethod matches static callees against it.
var forbidden = []string{
"FailNow",
"Fatal",
"Fatalf",
"Skip",
"Skipf",
"SkipNow",
}
// forbiddenMethod decomposes a call of the form x.m() into (x, x.m, m),
// where x is a variable, x.m is the recorded selection, and m is the
// selected method, which must be one of the forbidden methods of package
// testing. Parentheses around the callee and around x are ignored.
//
// It returns (nil, nil, nil) if call does not have this form.
func forbiddenMethod(info *types.Info, call *ast.CallExpr) (*types.Var, *types.Selection, *types.Func) {
	// Compare to typeutil.StaticCallee.
	selExpr, isSel := astutil.Unparen(call.Fun).(*ast.SelectorExpr)
	if !isSel {
		return nil, nil, nil
	}
	sel := info.Selections[selExpr]
	if sel == nil {
		return nil, nil, nil
	}

	// The receiver expression must be an identifier that denotes a variable.
	id, isIdent := astutil.Unparen(selExpr.X).(*ast.Ident)
	if !isIdent {
		return nil, nil, nil
	}
	x, _ := info.Uses[id].(*types.Var)
	if x == nil {
		return nil, nil, nil
	}

	fn, _ := sel.Obj().(*types.Func)
	if fn == nil || !isMethodNamed(fn, "testing", forbidden...) {
		return nil, nil, nil
	}
	return x, sel, fn
}
// formatMethod renders the method fn selected through sel in the form
// "(recv).name", e.g. "(*testing.T).Fatal". When the receiver type of the
// selection is a pointer, it is printed as "*" followed by the pointer's
// element type.
func formatMethod(sel *types.Selection, fn *types.Func) string {
	recv, star := sel.Recv(), ""
	if pointer, isPointer := recv.(*types.Pointer); isPointer {
		recv, star = pointer.Elem(), "*"
	}
	return fmt.Sprintf("(%s%s).%s", star, recv.String(), fn.Name())
}