2525 appendCode []string // Code to append at the end.
2626 returnVars []string // Return variables to modify.
2727 appendSwitch functionSwitches // Switch cases to append.
28+ removeCalls []string // Function calls to remove.
2829 }
2930
3031 // FunctionOptions configures code generation.
@@ -219,6 +220,15 @@ func AppendSwitchCase(condition, switchCase, switchBody string) FunctionOptions
219220 }
220221}
221222
223+ // RemoveFuncCall removes function calls with the specified name from within a function.
224+ // The callName can be either a simple function name like "doSomething" or a qualified
225+ // name like "pkg.DoSomething".
226+ func RemoveFuncCall (callName string ) FunctionOptions {
227+ return func (c * functionOpts ) {
228+ c .removeCalls = append (c .removeCalls , callName )
229+ }
230+ }
231+
222232// newFunctionOptions creates a new functionOpts with defaults.
223233func newFunctionOptions () functionOpts {
224234 return functionOpts {
@@ -230,6 +240,7 @@ func newFunctionOptions() functionOpts {
230240 appendTestCase : make ([]string , 0 ),
231241 appendCode : make ([]string , 0 ),
232242 returnVars : make ([]string , 0 ),
243+ removeCalls : make ([]string , 0 ),
233244 }
234245}
235246
@@ -635,6 +646,13 @@ func applyFunctionOptions(fileSet *token.FileSet, f *ast.FuncDecl, opts *functio
635646 switchesCasesMapCheck = opts .appendSwitch .Map ()
636647 )
637648
649+ // Remove function calls if specified.
650+ if len (opts .removeCalls ) > 0 {
651+ if err := removeFunctionCalls (f , opts .removeCalls ); err != nil {
652+ return err
653+ }
654+ }
655+
638656 // Apply all modifications.
639657 var errInspect error
640658 ast .Inspect (f , func (n ast.Node ) bool {
@@ -920,3 +938,146 @@ func ModifyCaller(content, callerExpr string, modifiers func([]string) ([]string
920938
921939 return string (result ), nil
922940}
941+
942+ // RemoveFunction removes a function declaration from the file content.
943+ func RemoveFunction (content , funcName string ) (string , error ) {
944+ // Parse source into AST.
945+ fset := token .NewFileSet ()
946+ file , err := parser .ParseFile (fset , "" , content , parser .ParseComments )
947+ if err != nil {
948+ return "" , errors .Errorf ("failed to parse file: %w" , err )
949+ }
950+
951+ cmap := ast .NewCommentMap (fset , file , file .Comments )
952+
953+ // Find the function to remove.
954+ var found bool
955+ var newDecls []ast.Decl
956+ for _ , decl := range file .Decls {
957+ if fd , ok := decl .(* ast.FuncDecl ); ok && fd .Name .Name == funcName {
958+ found = true
959+ // Remove comments associated with this function.
960+ delete (cmap , decl )
961+ continue // Skip this declaration to remove it.
962+ }
963+ newDecls = append (newDecls , decl )
964+ }
965+
966+ if ! found {
967+ return "" , errors .Errorf ("function %q not found" , funcName )
968+ }
969+
970+ // Update file declarations and comments.
971+ file .Decls = newDecls
972+ file .Comments = cmap .Filter (file ).Comments ()
973+
974+ return formatNode (fset , file )
975+ }
976+
977+ // removeFunctionCalls removes all function calls matching the specified names from a function.
978+ func removeFunctionCalls (f * ast.FuncDecl , callNames []string ) error {
979+ if f .Body == nil {
980+ return nil
981+ }
982+
983+ // Create a map for faster lookup.
984+ callMap := make (map [string ]bool )
985+ for _ , name := range callNames {
986+ callMap [name ] = true
987+ }
988+
989+ // Helper to check if a call expression matches any of the names to remove.
990+ matchesCall := func (callExpr * ast.CallExpr ) bool {
991+ switch fun := callExpr .Fun .(type ) {
992+ case * ast.Ident :
993+ // Simple function call like doSomething().
994+ return callMap [fun .Name ]
995+ case * ast.SelectorExpr :
996+ // Qualified function call like pkg.DoSomething().
997+ if ident , ok := fun .X .(* ast.Ident ); ok {
998+ qualified := ident .Name + "." + fun .Sel .Name
999+ return callMap [qualified ]
1000+ }
1001+ }
1002+ return false
1003+ }
1004+
1005+ // Filter statements to remove matching function calls.
1006+ var filterStmts func ([]ast.Stmt ) []ast.Stmt
1007+ filterStmts = func (stmts []ast.Stmt ) []ast.Stmt {
1008+ var filtered []ast.Stmt
1009+ for _ , stmt := range stmts {
1010+ keep := true
1011+
1012+ // Check if this is an expression statement with a call expression.
1013+ if exprStmt , ok := stmt .(* ast.ExprStmt ); ok {
1014+ if callExpr , ok := exprStmt .X .(* ast.CallExpr ); ok {
1015+ if matchesCall (callExpr ) {
1016+ keep = false
1017+ }
1018+ }
1019+ }
1020+
1021+ // Recursively handle block statements.
1022+ if blockStmt , ok := stmt .(* ast.BlockStmt ); ok {
1023+ blockStmt .List = filterStmts (blockStmt .List )
1024+ }
1025+
1026+ // Recursively handle if statements.
1027+ if ifStmt , ok := stmt .(* ast.IfStmt ); ok {
1028+ if ifStmt .Body != nil {
1029+ ifStmt .Body .List = filterStmts (ifStmt .Body .List )
1030+ }
1031+ if ifStmt .Else != nil {
1032+ if elseBlock , ok := ifStmt .Else .(* ast.BlockStmt ); ok {
1033+ elseBlock .List = filterStmts (elseBlock .List )
1034+ }
1035+ }
1036+ }
1037+
1038+ // Recursively handle for statements.
1039+ if forStmt , ok := stmt .(* ast.ForStmt ); ok {
1040+ if forStmt .Body != nil {
1041+ forStmt .Body .List = filterStmts (forStmt .Body .List )
1042+ }
1043+ }
1044+
1045+ // Recursively handle range statements.
1046+ if rangeStmt , ok := stmt .(* ast.RangeStmt ); ok {
1047+ if rangeStmt .Body != nil {
1048+ rangeStmt .Body .List = filterStmts (rangeStmt .Body .List )
1049+ }
1050+ }
1051+
1052+ // Recursively handle switch statements.
1053+ if switchStmt , ok := stmt .(* ast.SwitchStmt ); ok {
1054+ if switchStmt .Body != nil {
1055+ for _ , caseClause := range switchStmt .Body .List {
1056+ if cc , ok := caseClause .(* ast.CaseClause ); ok {
1057+ cc .Body = filterStmts (cc .Body )
1058+ }
1059+ }
1060+ }
1061+ }
1062+
1063+ // Recursively handle type switch statements.
1064+ if typeSwitchStmt , ok := stmt .(* ast.TypeSwitchStmt ); ok {
1065+ if typeSwitchStmt .Body != nil {
1066+ for _ , caseClause := range typeSwitchStmt .Body .List {
1067+ if cc , ok := caseClause .(* ast.CaseClause ); ok {
1068+ cc .Body = filterStmts (cc .Body )
1069+ }
1070+ }
1071+ }
1072+ }
1073+
1074+ if keep {
1075+ filtered = append (filtered , stmt )
1076+ }
1077+ }
1078+ return filtered
1079+ }
1080+
1081+ f .Body .List = filterStmts (f .Body .List )
1082+ return nil
1083+ }
0 commit comments