Skip to content

Commit

Permalink
Merge pull request #7 from Buzzvil/improve_usability
Browse files Browse the repository at this point in the history
Simplify inspection steps
  • Loading branch information
dc7303 authored May 17, 2023
2 parents f4f3ec8 + 649a906 commit 99e488c
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 149 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ go install github.com/Buzzvil/recovergoroutine
recovergoroutine -recover="" ./...

# -recover string
# Custom recover method name. Currently, it is difficult to determine
# if a CustomRecover function declared in another package is valid,
# so this option can be used to resolve it.
# Custom recovery method name. You can use this option
# when you want to call a method defined in a struct or
# use CustomRecover declared in an external package.
```

Check out the test cases for validation [examples](./test/src/faildata/failcode.go).
Expand Down
113 changes: 17 additions & 96 deletions recovergoroutine/recovergoroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@ package recovergoroutine

import (
"flag"
"fmt"
"go/ast"
"go/parser"
"go/types"
"reflect"

"golang.org/x/tools/go/analysis"
)

type message string

var customRecover string

func NewAnalyzer() *analysis.Analyzer {
Expand All @@ -25,8 +22,7 @@ func NewAnalyzer() *analysis.Analyzer {
&customRecover,
"recover",
"",
"It is difficult to determine if a CustomRecover function declared in another package is valid,"+
" so this option can be used to resolve it.",
"You can use this option when you want to call a method defined in a struct or use CustomRecover declared in an external package.",
)

return analyzer
Expand All @@ -41,12 +37,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
return true
}

ok, err := safeGoStmt(goStmt, pass)
if err != nil {
runErr = err
return false
}

ok, msg := safeGoStmt(goStmt)
if ok {
return true
}
Expand All @@ -55,7 +46,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
Pos: goStmt.Pos(),
End: 0,
Category: "goroutine",
Message: "goroutine must have recover",
Message: string(msg),
})

return false
Expand All @@ -65,43 +56,28 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, runErr
}

func safeGoStmt(goStmt *ast.GoStmt, pass *analysis.Pass) (bool, error) {
func safeGoStmt(goStmt *ast.GoStmt) (bool, message) {
fn := goStmt.Call
switch fun := fn.Fun.(type) {
case *ast.SelectorExpr:
return safeSelectorExpr(fun, pass, safeFunc)
case *ast.FuncLit:
return safeFunc(fun, pass)
case *ast.Ident:
if fun.Obj == nil {
return false, nil
}

funcDecl, ok := fun.Obj.Decl.(*ast.FuncDecl)
if !ok {
return false, nil
if !safeFunc(fun) {
return false, "goroutine must have recover"
}

return safeFunc(funcDecl, pass)
return true, ""
}

return false, fmt.Errorf("unexpected goroutine function type: %v", reflect.TypeOf(fn.Fun).String())
return false, "use function literals when using goroutines"
}

func safeFunc(node ast.Node, pass *analysis.Pass) (bool, error) {
func safeFunc(node ast.Node) bool {
result := false
var err error
ast.Inspect(node, func(node ast.Node) bool {
deferStmt, ok := node.(*ast.DeferStmt)
if !ok {
return true
}

ok, err = hasRecover(deferStmt.Call, pass)
if err != nil {
return false
}

ok = hasRecover(deferStmt.Call)
if ok {
result = true
return false
Expand All @@ -110,12 +86,11 @@ func safeFunc(node ast.Node, pass *analysis.Pass) (bool, error) {
return !result
})

return result, err
return result
}

func hasRecover(expr ast.Node, pass *analysis.Pass) (bool, error) {
func hasRecover(expr ast.Node) bool {
var result bool
var err error
ast.Inspect(expr, func(node ast.Node) bool {
switch n := node.(type) {
case *ast.CallExpr:
Expand All @@ -128,69 +103,15 @@ func hasRecover(expr ast.Node, pass *analysis.Pass) (bool, error) {
return true
}

var ok bool
ok, err = safeSelectorExpr(n, pass, hasRecover)
if err != nil {
return false
}

if ok || n.Sel.Name == customRecover {
if n.Sel.Name == customRecover {
result = true
return false
}
}
return true
})

return result, err
}

func safeSelectorExpr(
expr *ast.SelectorExpr,
pass *analysis.Pass,
methodChecker func(node ast.Node, pass *analysis.Pass) (bool, error),
) (bool, error) {
ident, ok := expr.X.(*ast.Ident)
if !ok {
return false, nil
}

methodName := expr.Sel.Name
objType := pass.TypesInfo.ObjectOf(ident)
pointerType, ok := objType.Type().(*types.Pointer)
if !ok {
return false, nil
}

named, ok := pointerType.Elem().(*types.Named)
if !ok {
return false, nil
}

result := false
for i := 0; i < named.NumMethods(); i++ {
if named.Method(i).Name() != methodName {
continue
}

fset := pass.Fset
position := fset.Position(named.Method(i).Pos())
file, err := parser.ParseFile(fset, position.Filename, nil, 0)
if err != nil {
return false, fmt.Errorf("parse file: %w", err)
}

for _, decl := range file.Decls {
if funcDecl, ok := decl.(*ast.FuncDecl); ok {
if funcDecl.Name.Name == methodName {
result, err = methodChecker(funcDecl, pass)
break
}
}
}
}

return result, nil
return result
}

func isRecover(callExpr *ast.CallExpr) bool {
Expand All @@ -199,7 +120,7 @@ func isRecover(callExpr *ast.CallExpr) bool {
return false
}

return ident.Name == "recover"
return ident.Name == "recover" || ident.Name == customRecover
}

func isCustomRecover(callExpr *ast.CallExpr) bool {
Expand Down
1 change: 1 addition & 0 deletions test/src/custom/recover.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package custom
54 changes: 4 additions & 50 deletions test/src/succdata/succcode.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package succdata

func whenASTFuncLit() {
go func() {
defer recover()
}()

go func() {
defer func() {
if r := recover(); r != nil {
Expand All @@ -23,54 +27,4 @@ func whenASTFuncLit() {

defer rec()
}()

go func() {
defer customRecover()
}()

}

func whenIdent() {
go runGoroutine()
go nestedFunc1()
}

func whenCallMethod() {
foo := &Foo{}
go foo.run()
go func() {
defer foo.Recover()
}()
}

func runGoroutine() {
defer func() {
recover()
}()
}

func nestedFunc1() {
// must have recover in parent caller
nestedFunc2()
defer func() {
recover()
}()
}

func nestedFunc2() {}

func customRecover() {
recover()
}

type Foo struct{}

func (a *Foo) run() {
defer func() {
recover()
}()
}

func (a *Foo) Recover() {
recover()
}

0 comments on commit 99e488c

Please sign in to comment.