diff --git a/demo/basic/main.go b/demo/basic/main.go index c08f00e2..d5ff665b 100644 --- a/demo/basic/main.go +++ b/demo/basic/main.go @@ -31,6 +31,20 @@ type MyStruct struct{} func (m *MyStruct) Example() { println("MyStruct.Example") } +type GenStruct[T any] struct { + Value T +} + +func (m *GenStruct[T]) GenericRecvExample(t T) T { + fmt.Printf("%s%s\n", m.Value, t) + return t +} + +func GenericExample[K comparable, V any](key K, value V) V { + println("Hello, Generic World!", key, value) + return value +} + // Example demonstrates how to use the instrumenter. func Example() { // Output: @@ -58,6 +72,10 @@ func main() { m.NewField = "abc" m.Example() + _ = GenericExample(1, 2) + g := &GenStruct[string]{Value: "Hello"} + _ = g.GenericRecvExample(", Generic Recv World!") + // Call real module function println(rate.Every(time.Duration(1))) } diff --git a/pkg/instrumentation/helloworld/helloworld_hook.go b/pkg/instrumentation/helloworld/helloworld_hook.go index 8a379adf..a2a91367 100644 --- a/pkg/instrumentation/helloworld/helloworld_hook.go +++ b/pkg/instrumentation/helloworld/helloworld_hook.go @@ -71,4 +71,32 @@ func MyHook1After(ictx inst.HookContext) { println("After MyStruct.Example()") } +func MyHookRecvBefore(ictx inst.HookContext, recv, _ interface{}) { + println("GenericRecvExample before hook") +} + +func MyHookRecvAfter(ictx inst.HookContext, _ interface{}) { + println("GenericRecvExample after hook") +} + +func MyHookGenericBefore(ictx inst.HookContext, _, _ interface{}) { + println("GenericExample before hook") + fmt.Printf("[Generic] Function: %s.%s\n", ictx.GetPackageName(), ictx.GetFuncName()) + fmt.Printf("[Generic] Param count: %d\n", ictx.GetParamCount()) + fmt.Printf("[Generic] Skip call: %v\n", ictx.IsSkipCall()) + for i := 0; i < ictx.GetParamCount(); i++ { + fmt.Printf("[Generic] Param[%d]: %v\n", i, *ictx.GetParam(i).(*int)) + } + ictx.SetData("test-data") +} + +func MyHookGenericAfter(ictx inst.HookContext, _ interface{}) { + println("GenericExample after hook") + fmt.Printf("[Generic] Data from Before: %v\n", ictx.GetData()) + fmt.Printf("[Generic] Return value count: %d\n", ictx.GetReturnValCount()) + for i := 0; i < ictx.GetReturnValCount(); i++ { + fmt.Printf("[Generic] Return[%d]: %v\n", i, *ictx.GetReturnVal(i).(*int)) + } +} + func BeforeUnderscore(ictx inst.HookContext, _ int, _ float32) {} diff --git a/test/integration/basic_test.go b/test/integration/basic_test.go index e6aed15b..fb21660e 100644 --- a/test/integration/basic_test.go +++ b/test/integration/basic_test.go @@ -23,6 +23,13 @@ func TestBasic(t *testing.T) { "Every1", "Every3", "MyStruct.Example", + "GenericExample before hook", + "Hello, Generic World! 1 2", + "GenericExample after hook", + "traceID: 123, spanID: 456", + "GenericRecvExample before hook", + "Hello, Generic Recv World!", + "GenericRecvExample after hook", "traceID: 123, spanID: 456", "[MyHook]", "=setupOpenTelemetry=", @@ -31,4 +38,22 @@ func TestBasic(t *testing.T) { for _, e := range expect { require.Contains(t, output, e) } + + verifyGenericHookContextLogs(t, output) +} + +func verifyGenericHookContextLogs(t *testing.T, output string) { + expectedGenericLogs := []string{ + "[Generic] Function: main.GenericExample", + "[Generic] Param count: 2", + "[Generic] Skip call: false", + "[Generic] Param[0]: 1", + "[Generic] Param[1]: 2", + "[Generic] Data from Before: test-data", + "[Generic] Return value count: 1", + "[Generic] Return[0]: 2", + } + for _, log := range expectedGenericLogs { + require.Contains(t, output, log, "Expected generic HookContext log: %s", log) + } } diff --git a/tool/data/helloworld.yaml b/tool/data/helloworld.yaml index 6405a2a9..519e04cf 100644 --- a/tool/data/helloworld.yaml +++ b/tool/data/helloworld.yaml @@ -57,3 +57,18 @@ underscore_param: func: Underscore before: BeforeUnderscore path: "github.com/open-telemetry/opentelemetry-go-compile-instrumentation/pkg/instrumentation/helloworld" + +hook_generic: + target: main + func: GenericExample + before: MyHookGenericBefore + after: MyHookGenericAfter + path: "github.com/open-telemetry/opentelemetry-go-compile-instrumentation/pkg/instrumentation/helloworld" + +hook_generic_recv: + target: main + func: GenericRecvExample + recv: "*GenStruct" + before: MyHookRecvBefore + after: MyHookRecvAfter + path: "github.com/open-telemetry/opentelemetry-go-compile-instrumentation/pkg/instrumentation/helloworld" diff --git a/tool/internal/ast/primitives.go b/tool/internal/ast/primitives.go index 9bc4ecf7..b514eeb6 100644 --- a/tool/internal/ast/primitives.go +++ b/tool/internal/ast/primitives.go @@ -43,9 +43,31 @@ func AddressOf(name string) *dst.UnaryExpr { return &dst.UnaryExpr{Op: token.AND, X: Ident(name)} } -func CallTo(name string, args []dst.Expr) *dst.CallExpr { +// CallTo creates a call expression to a function with optional type arguments for generics. +// For non-generic functions (typeArgs is nil or empty), creates a simple call: Foo(args...) +// For generic functions with type arguments, creates: Foo[T1, T2](args...) +func CallTo(name string, typeArgs *dst.FieldList, args []dst.Expr) *dst.CallExpr { + if typeArgs == nil || len(typeArgs.List) == 0 { + return &dst.CallExpr{ + Fun: &dst.Ident{Name: name}, + Args: args, + } + } + + var indices []dst.Expr + for _, field := range typeArgs.List { + for _, ident := range field.Names { + indices = append(indices, Ident(ident.Name)) + } + } + var fun dst.Expr + if len(indices) == 1 { + fun = IndexExpr(Ident(name), indices[0]) + } else { + fun = IndexListExpr(Ident(name), indices) + } return &dst.CallExpr{ - Fun: &dst.Ident{Name: name}, + Fun: fun, Args: args, } } @@ -103,6 +125,14 @@ func IndexExpr(x, index dst.Expr) *dst.IndexExpr { } } +func IndexListExpr(x dst.Expr, indices []dst.Expr) *dst.IndexListExpr { + e := util.AssertType[dst.Expr](dst.Clone(x)) + return &dst.IndexListExpr{ + X: e, + Indices: indices, + } +} + func TypeAssertExpr(x, t dst.Expr) *dst.TypeAssertExpr { e := util.AssertType[dst.Expr](dst.Clone(t)) return &dst.TypeAssertExpr{ @@ -275,3 +305,14 @@ func StructLit(typeName string, fields ...*dst.KeyValueExpr) dst.Expr { X: CompositeLit(Ident(typeName), exprs), } } + +// CloneTypeParams safely clones a type parameter field list for generic functions. +// Returns nil if the input is nil. +func CloneTypeParams(typeParams *dst.FieldList) *dst.FieldList { + if typeParams == nil { + return nil + } + cloned, ok := dst.Clone(typeParams).(*dst.FieldList) + util.Assert(ok, "typeParams is not a FieldList") + return cloned +} diff --git a/tool/internal/ast/primitives_test.go b/tool/internal/ast/primitives_test.go new file mode 100644 index 00000000..6a76222c --- /dev/null +++ b/tool/internal/ast/primitives_test.go @@ -0,0 +1,295 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package ast + +import ( + "testing" + + "github.com/dave/dst" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func assertSimpleCall(t *testing.T, expr *dst.CallExpr, expectedFuncName string, expectedArgCount int) { + funcIdent, _ := expr.Fun.(*dst.Ident) + assert.Equal(t, expectedFuncName, funcIdent.Name) + assert.Len(t, expr.Args, expectedArgCount) +} + +func assertIndexExprCall( + t *testing.T, + expr *dst.CallExpr, + expectedFuncName string, + expectedTypeParam string, + expectedArgCount int, +) { + indexExpr, _ := expr.Fun.(*dst.IndexExpr) + funcIdent, _ := indexExpr.X.(*dst.Ident) + assert.Equal(t, expectedFuncName, funcIdent.Name) + typeIdent, _ := indexExpr.Index.(*dst.Ident) + assert.Equal(t, expectedTypeParam, typeIdent.Name) + assert.Len(t, expr.Args, expectedArgCount) +} + +func assertIndexListExprCall( + t *testing.T, + expr *dst.CallExpr, + expectedFuncName string, + expectedTypeParams []string, + expectedArgCount int, +) { + indexListExpr, _ := expr.Fun.(*dst.IndexListExpr) + funcIdent, _ := indexListExpr.X.(*dst.Ident) + assert.Equal(t, expectedFuncName, funcIdent.Name) + require.Len(t, indexListExpr.Indices, len(expectedTypeParams)) + for i, expectedParam := range expectedTypeParams { + paramIdent, _ := indexListExpr.Indices[i].(*dst.Ident) + assert.Equal(t, expectedParam, paramIdent.Name) + } + assert.Len(t, expr.Args, expectedArgCount) +} + +// Helper function to parse a complete function and extract its type parameters +func parseFuncTypeParams(t *testing.T, funcSource string) *dst.FieldList { + parser := NewAstParser() + file, err := parser.ParseSource("package main\n" + funcSource) + require.NoError(t, err) + require.Len(t, file.Decls, 1) + funcDecl, ok := file.Decls[0].(*dst.FuncDecl) + require.True(t, ok) + return funcDecl.Type.TypeParams +} + +func TestCallTo(t *testing.T) { + tests := []struct { + name string + funcName string + funcSource string // Source code for parsing type params + args []dst.Expr + validate func(*testing.T, *dst.CallExpr) + }{ + { + name: "nil type params returns simple call", + funcName: "Foo", + funcSource: "func Foo(x, y int) {}", // No type params + args: []dst.Expr{Ident("x"), Ident("y")}, + validate: func(t *testing.T, expr *dst.CallExpr) { + assertSimpleCall(t, expr, "Foo", 2) + }, + }, + { + name: "single type parameter creates IndexExpr", + funcName: "GenericFunc", + funcSource: "func GenericFunc[T any](value T) {}", + args: []dst.Expr{Ident("value")}, + validate: func(t *testing.T, expr *dst.CallExpr) { + assertIndexExprCall(t, expr, "GenericFunc", "T", 1) + }, + }, + { + name: "multiple type parameters creates IndexListExpr", + funcName: "MultiGeneric", + funcSource: "func MultiGeneric[T any, U comparable](x T, y U) {}", + args: []dst.Expr{Ident("x"), Ident("y")}, + validate: func(t *testing.T, expr *dst.CallExpr) { + assertIndexListExprCall(t, expr, "MultiGeneric", []string{"T", "U"}, 2) + }, + }, + { + name: "field with multiple names creates multiple indices", + funcName: "MultiNameGeneric", + funcSource: "func MultiNameGeneric[T, U any](value T) {}", + args: []dst.Expr{Ident("value")}, + validate: func(t *testing.T, expr *dst.CallExpr) { + assertIndexListExprCall(t, expr, "MultiNameGeneric", []string{"T", "U"}, 1) + }, + }, + { + name: "no arguments with type parameters", + funcName: "NoArgsGeneric", + funcSource: "func NoArgsGeneric[T any]() {}", + args: []dst.Expr{}, + validate: func(t *testing.T, expr *dst.CallExpr) { + assertIndexExprCall(t, expr, "NoArgsGeneric", "T", 0) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + typeParams := parseFuncTypeParams(t, tt.funcSource) + result := CallTo(tt.funcName, typeParams, tt.args) + require.NotNil(t, result) + tt.validate(t, result) + }) + } +} + +func TestCloneTypeParams(t *testing.T) { + t.Run("nil input returns nil", func(t *testing.T) { + assert.Nil(t, CloneTypeParams(nil)) + }) + + t.Run("clones are independent instances with same content", func(t *testing.T) { + testCases := []struct { + name string + original *dst.FieldList + }{ + { + name: "single type parameter", + original: &dst.FieldList{ + List: []*dst.Field{ + {Names: []*dst.Ident{Ident("T")}, Type: Ident("any")}, + }, + }, + }, + { + name: "multiple type parameters", + original: &dst.FieldList{ + List: []*dst.Field{ + {Names: []*dst.Ident{Ident("T")}, Type: Ident("any")}, + {Names: []*dst.Ident{Ident("U")}, Type: Ident("comparable")}, + }, + }, + }, + { + name: "field with multiple names", + original: &dst.FieldList{ + List: []*dst.Field{ + {Names: []*dst.Ident{Ident("T"), Ident("U")}, Type: Ident("any")}, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cloned := CloneTypeParams(tc.original) + assert.NotSame(t, tc.original, cloned) + assert.Equal(t, tc.original, cloned) + }) + } + }) + + t.Run("modifications to clone don't affect original", func(t *testing.T) { + original := &dst.FieldList{ + List: []*dst.Field{ + {Names: []*dst.Ident{Ident("T")}, Type: Ident("any")}, + }, + } + cloned := CloneTypeParams(original) + + cloned.List[0].Names[0].Name = "Modified" + + assert.Equal(t, "T", original.List[0].Names[0].Name) + assert.Equal(t, "Modified", cloned.List[0].Names[0].Name) + }) +} + +func TestSplitMultiNameFields(t *testing.T) { + t.Run("nil input returns nil", func(t *testing.T) { + assert.Nil(t, SplitMultiNameFields(nil)) + }) + + t.Run("empty field list returns empty list", func(t *testing.T) { + input := &dst.FieldList{List: []*dst.Field{}} + result := SplitMultiNameFields(input) + assert.NotNil(t, result) + assert.Empty(t, result.List) + }) + + t.Run("single name fields remain unchanged", func(t *testing.T) { + input := &dst.FieldList{ + List: []*dst.Field{ + {Names: []*dst.Ident{Ident("a")}, Type: Ident("int")}, + {Names: []*dst.Ident{Ident("b")}, Type: Ident("string")}, + }, + } + result := SplitMultiNameFields(input) + require.Len(t, result.List, 2) + assert.Equal(t, "a", result.List[0].Names[0].Name) + assert.Equal(t, "int", result.List[0].Type.(*dst.Ident).Name) + assert.Equal(t, "b", result.List[1].Names[0].Name) + assert.Equal(t, "string", result.List[1].Type.(*dst.Ident).Name) + }) + + t.Run("multi-name field is split into separate fields", func(t *testing.T) { + input := &dst.FieldList{ + List: []*dst.Field{ + { + Names: []*dst.Ident{Ident("a"), Ident("b")}, + Type: Ident("int"), + }, + }, + } + result := SplitMultiNameFields(input) + require.Len(t, result.List, 2) + assert.Equal(t, "a", result.List[0].Names[0].Name) + assert.Equal(t, "int", result.List[0].Type.(*dst.Ident).Name) + assert.Equal(t, "b", result.List[1].Names[0].Name) + assert.Equal(t, "int", result.List[1].Type.(*dst.Ident).Name) + }) + + t.Run("underscore parameters are properly split", func(t *testing.T) { + input := &dst.FieldList{ + List: []*dst.Field{ + { + Names: []*dst.Ident{Ident("_"), Ident("_")}, + Type: InterfaceType(), + }, + }, + } + result := SplitMultiNameFields(input) + require.Len(t, result.List, 2) + assert.Equal(t, "_", result.List[0].Names[0].Name) + assert.NotNil(t, result.List[0].Type.(*dst.InterfaceType)) + assert.Equal(t, "_", result.List[1].Names[0].Name) + assert.NotNil(t, result.List[1].Type.(*dst.InterfaceType)) + }) + + t.Run("mixed single and multi-name fields", func(t *testing.T) { + input := &dst.FieldList{ + List: []*dst.Field{ + {Names: []*dst.Ident{Ident("a")}, Type: Ident("int")}, + {Names: []*dst.Ident{Ident("b"), Ident("c")}, Type: Ident("string")}, + {Names: []*dst.Ident{Ident("d")}, Type: Ident("bool")}, + }, + } + result := SplitMultiNameFields(input) + require.Len(t, result.List, 4) + assert.Equal(t, "a", result.List[0].Names[0].Name) + assert.Equal(t, "int", result.List[0].Type.(*dst.Ident).Name) + assert.Equal(t, "b", result.List[1].Names[0].Name) + assert.Equal(t, "string", result.List[1].Type.(*dst.Ident).Name) + assert.Equal(t, "c", result.List[2].Names[0].Name) + assert.Equal(t, "string", result.List[2].Type.(*dst.Ident).Name) + assert.Equal(t, "d", result.List[3].Names[0].Name) + assert.Equal(t, "bool", result.List[3].Type.(*dst.Ident).Name) + }) + + t.Run("unnamed field remains unchanged", func(t *testing.T) { + input := &dst.FieldList{ + List: []*dst.Field{ + {Names: nil, Type: Ident("int")}, + }, + } + result := SplitMultiNameFields(input) + require.Len(t, result.List, 1) + assert.Nil(t, result.List[0].Names) + assert.Equal(t, "int", result.List[0].Type.(*dst.Ident).Name) + }) + + t.Run("modifications to result don't affect original", func(t *testing.T) { + original := &dst.FieldList{ + List: []*dst.Field{ + {Names: []*dst.Ident{Ident("a"), Ident("b")}, Type: Ident("int")}, + }, + } + result := SplitMultiNameFields(original) + + result.List[0].Names[0].Name = "Modified" + + assert.Equal(t, "a", original.List[0].Names[0].Name) + assert.Equal(t, "Modified", result.List[0].Names[0].Name) + }) +} diff --git a/tool/internal/ast/shared.go b/tool/internal/ast/shared.go index 71ed8aab..b6faf30a 100644 --- a/tool/internal/ast/shared.go +++ b/tool/internal/ast/shared.go @@ -44,6 +44,48 @@ func FindFuncDeclWithoutRecv(root *dst.File, funcName string) *dst.FuncDecl { return decls[0] } +// extractBaseReceiverType extracts the base type name from a receiver expression, +// handling both generic and non-generic types. +// For example: +// - *MyStruct -> *MyStruct +// - MyStruct -> MyStruct +// - *GenStruct[T] -> *GenStruct +// - GenStruct[T] -> GenStruct +func extractBaseReceiverType(recvTypeExpr dst.Expr) string { + switch expr := recvTypeExpr.(type) { + case *dst.StarExpr: // func (*Recv)T or func (*Recv[T])T + // Check if X is an Ident (non-generic) or IndexExpr/IndexListExpr (generic) + switch x := expr.X.(type) { + case *dst.Ident: + // Non-generic pointer receiver: *MyStruct + return "*" + x.Name + case *dst.IndexExpr: + // Generic pointer receiver with single type param: *GenStruct[T] + if baseIdent, ok := x.X.(*dst.Ident); ok { + return "*" + baseIdent.Name + } + case *dst.IndexListExpr: + // Generic pointer receiver with multiple type params: *GenStruct[T, U] + if baseIdent, ok := x.X.(*dst.Ident); ok { + return "*" + baseIdent.Name + } + } + case *dst.Ident: // func (Recv)T + return expr.Name + case *dst.IndexExpr: + // Generic value receiver with single type param: GenStruct[T] + if baseIdent, ok := expr.X.(*dst.Ident); ok { + return baseIdent.Name + } + case *dst.IndexListExpr: + // Generic value receiver with multiple type params: GenStruct[T, U] + if baseIdent, ok := expr.X.(*dst.Ident); ok { + return baseIdent.Name + } + } + return "" +} + func FindFuncDecl(root *dst.File, funcName, recv string) *dst.FuncDecl { decls := findFuncDecls(root, func(funcDecl *dst.FuncDecl) bool { // Receiver type is ignored, match func name only @@ -59,26 +101,15 @@ func FindFuncDecl(root *dst.File, funcName, recv string) *dst.FuncDecl { // Receiver type is specified, and target function has receiver // Match both func name and receiver type - switch recvTypeExpr := funcDecl.Recv.List[0].Type.(type) { - case *dst.StarExpr: // func (*Recv)T - tn, ok := recvTypeExpr.X.(*dst.Ident) - if !ok { - // This is a generic type, we don't support it yet - return false - } - t := "*" + tn.Name - return t == recv && name == funcName - case *dst.Ident: // func (Recv)T - t := recvTypeExpr.Name - return t == recv && name == funcName - case *dst.IndexExpr: - // This is a generic type, we don't support it yet - return false - default: + recvTypeExpr := funcDecl.Recv.List[0].Type + baseType := extractBaseReceiverType(recvTypeExpr) + + if baseType == "" { msg := fmt.Sprintf("unexpected receiver type: %T", recvTypeExpr) util.Unimplemented(msg) } - return false + + return baseType == recv && name == funcName }) if len(decls) == 0 { @@ -154,3 +185,39 @@ func AddStructField(decl dst.Decl, name, t string) { st := util.AssertType[*dst.StructType](ty.Type) st.Fields.List = append(st.Fields.List, fd) } + +// SplitMultiNameFields splits fields that have multiple names into separate fields. +// For example, a field like "a, b int" becomes two fields: "a int" and "b int". +func SplitMultiNameFields(fieldList *dst.FieldList) *dst.FieldList { + if fieldList == nil { + return nil + } + result := &dst.FieldList{List: []*dst.Field{}} + for _, field := range fieldList.List { + // Handle unnamed fields (e.g., embedded types) or fields with single/multiple names + namesToProcess := field.Names + if len(namesToProcess) == 0 { + // For unnamed fields, create one field with no names + namesToProcess = []*dst.Ident{nil} + } + + for _, name := range namesToProcess { + clonedType, ok := dst.Clone(field.Type).(dst.Expr) + util.Assert(ok, "field.Type is not an Expr") + + var names []*dst.Ident + if name != nil { + clonedName, okC := dst.Clone(name).(*dst.Ident) + util.Assert(okC, "name is not an Ident") + names = []*dst.Ident{clonedName} + } + + newField := &dst.Field{ + Names: names, + Type: clonedType, + } + result.List = append(result.List, newField) + } + } + return result +} diff --git a/tool/internal/instrument/apply_func.go b/tool/internal/instrument/apply_func.go index 66fbc087..72a9d85e 100644 --- a/tool/internal/instrument/apply_func.go +++ b/tool/internal/instrument/apply_func.go @@ -113,8 +113,8 @@ func createTJumpIf(t *rule.InstFuncRule, funcDecl *dst.FuncDecl, argsToAfter := createHookArgs(retVals) argHookContext := ast.Ident(trampolineHookContextName + funcSuffix) argsToAfter = append([]dst.Expr{argHookContext}, argsToAfter...) - beforeCall := ast.CallTo(makeName(t, funcDecl, true), argsToBefore) - afterCall := ast.CallTo(makeName(t, funcDecl, false), argsToAfter) + beforeCall := ast.CallTo(makeName(t, funcDecl, true), funcDecl.Type.TypeParams, argsToBefore) + afterCall := ast.CallTo(makeName(t, funcDecl, false), funcDecl.Type.TypeParams, argsToAfter) tjumpInit := ast.DefineStmts( ast.Exprs( ast.Ident(trampolineHookContextName+funcSuffix), diff --git a/tool/internal/instrument/trampoline.go b/tool/internal/instrument/trampoline.go index be82ec1d..83f5ae7a 100644 --- a/tool/internal/instrument/trampoline.go +++ b/tool/internal/instrument/trampoline.go @@ -233,8 +233,9 @@ func getHookParamTraits(t *rule.InstFuncRule, before bool) ([]ParamTrait, error) return nil, err } attrs := make([]ParamTrait, 0) + splitParams := ast.SplitMultiNameFields(target.Type.Params) // Find which parameter is type of interface{} - for i, field := range target.Type.Params.List { + for i, field := range splitParams.List { attr := ParamTrait{Index: i} if ast.IsInterfaceType(field.Type) { attr.IsInterfaceAny = true @@ -268,7 +269,7 @@ func (ip *InstrumentPhase) callBeforeHook(t *rule.InstFuncRule, traits []ParamTr } } fnName := makeOnXName(t, true) - call := ast.ExprStmt(ast.CallTo(fnName, args)) + call := ast.ExprStmt(ast.CallTo(fnName, nil, args)) iff := ast.IfNotNilStmt( ast.Ident(fnName), ast.Block(call), @@ -305,7 +306,7 @@ func (ip *InstrumentPhase) callAfterHook(t *rule.InstFuncRule, traits []ParamTra } } fnName := makeOnXName(t, false) - call := ast.ExprStmt(ast.CallTo(fnName, args)) + call := ast.ExprStmt(ast.CallTo(fnName, nil, args)) iff := ast.IfNotNilStmt( ast.Ident(fnName), ast.Block(call), @@ -315,31 +316,27 @@ func (ip *InstrumentPhase) callAfterHook(t *rule.InstFuncRule, traits []ParamTra return nil } -func rectifyAnyType(paramList *dst.FieldList, traits []ParamTrait) error { - if len(paramList.List) != len(traits) { +func (ip *InstrumentPhase) addHookFuncVar(t *rule.InstFuncRule, + traits []ParamTrait, before bool, +) error { + paramTypes := ip.buildTrampolineType(before) + addHookContext(paramTypes) + combinedTypeParams := combineTypeParams(ip.targetFunc) + + if len(paramTypes.List) != len(traits) { return ex.New("hook func signature can not match with target function") } - for i, field := range paramList.List { + + for i, field := range paramTypes.List { trait := traits[i] if trait.IsInterfaceAny { - // Rectify type to "interface{}" + // Hook explicitly uses interface{} for this parameter field.Type = ast.InterfaceType() + } else { + // Replace type parameters with interface{} (for linkname compatibility) + field.Type = replaceTypeParamsWithAny(field.Type, combinedTypeParams) } } - return nil -} - -func (ip *InstrumentPhase) addHookFuncVar(t *rule.InstFuncRule, - traits []ParamTrait, before bool, -) error { - paramTypes := ip.buildTrampolineType(before) - addHookContext(paramTypes) - // Hook functions may uses interface{} as parameter type, as some types of - // raw function is not exposed - err := rectifyAnyType(paramTypes, traits) - if err != nil { - return err - } // Generate var decl and append it to the target file, note that many target // functions may match the same hook function, it's a fatal error to append @@ -426,20 +423,24 @@ func (ip *InstrumentPhase) buildTrampolineType(before bool) *dst.FieldList { // func S(h* HookContext, recv type, arg1 type, arg2 type, ...) // For after trampoline, it's signature is: // func S(h* HookContext, arg1 type, arg2 type, ...) + // All grouped parameters (like a, b int) are expanded into separate parameters (a int, b int) paramList := &dst.FieldList{List: []*dst.Field{}} if before { if ast.HasReceiver(ip.targetFunc) { - recvField := util.AssertType[*dst.Field](dst.Clone(ip.targetFunc.Recv.List[0])) + splitRecv := ast.SplitMultiNameFields(ip.targetFunc.Recv) + recvField := util.AssertType[*dst.Field](dst.Clone(splitRecv.List[0])) renameField(recvField, "recv") paramList.List = append(paramList.List, recvField) } - for _, field := range ip.targetFunc.Type.Params.List { + splitParams := ast.SplitMultiNameFields(ip.targetFunc.Type.Params) + for _, field := range splitParams.List { paramField := util.AssertType[*dst.Field](dst.Clone(field)) renameField(paramField, "param") paramList.List = append(paramList.List, paramField) } } else if ip.targetFunc.Type.Results != nil { - for _, field := range ip.targetFunc.Type.Results.List { + splitResults := ast.SplitMultiNameFields(ip.targetFunc.Type.Results) + for _, field := range splitResults.List { retField := util.AssertType[*dst.Field](dst.Clone(field)) renameField(retField, "arg") paramList.List = append(paramList.List, retField) @@ -464,6 +465,9 @@ func (ip *InstrumentPhase) buildTrampolineTypes() { } } addHookContext(afterHookFunc.Type.Params) + trampolineTypeParams := combineTypeParams(ip.targetFunc) + beforeHookFunc.Type.TypeParams = ast.CloneTypeParams(trampolineTypeParams) + afterHookFunc.Type.TypeParams = ast.CloneTypeParams(trampolineTypeParams) } func assignString(assignStmt *dst.AssignStmt, val string) bool { @@ -629,6 +633,108 @@ func setReturnValClause(idx int, t dst.Expr) *dst.CaseClause { return setValue(trampolineReturnValsIdentifier, idx, t) } +// extractReceiverTypeParams extracts type parameters from a receiver type expression +// For example: *GenStruct[T] or GenStruct[T, U] -> FieldList with T and U as type parameters +func extractReceiverTypeParams(recvType dst.Expr) *dst.FieldList { + switch t := recvType.(type) { + case *dst.StarExpr: + // *GenStruct[T] - recurse into X + return extractReceiverTypeParams(t.X) + case *dst.IndexExpr: + // GenStruct[T] - single type parameter + if ident, ok := t.Index.(*dst.Ident); ok { + return &dst.FieldList{ + List: []*dst.Field{{ + Names: []*dst.Ident{ident}, + Type: ast.Ident("any"), // Type constraint for the parameter + }}, + } + } + case *dst.IndexListExpr: + // GenStruct[T, U, ...] - multiple type parameters + fields := make([]*dst.Field, 0, len(t.Indices)) + for _, idx := range t.Indices { + if ident, ok := idx.(*dst.Ident); ok { + fields = append(fields, &dst.Field{ + Names: []*dst.Ident{ident}, + Type: ast.Ident("any"), // Type constraint for the parameter + }) + } + } + if len(fields) > 0 { + return &dst.FieldList{List: fields} + } + } + return nil +} + +// combineTypeParams combines type parameters from the receiver and function type parameters. +// For methods on generic types, it extracts type parameters from the receiver and merges +// them with the function's type parameters. +// Receiver type parameters come first, followed by function type parameters. +// +// Example: +// +// Original: func (c *Container[K]) Transform[V any]() V +// Result: [K, V] +// +// Generated trampolines: +// func OtelBeforeTrampoline_Container_Transform[K comparable, V any]( +// hookContext *HookContext, +// recv0 *Container[K], // ← Uses K +// ) { ... } +// +// func OtelAfterTrampoline_Container_Transform[K comparable, V any]( +// hookContext *HookContext, +// arg0 *V, // ← Uses V (return type) +// ) { ... } +func combineTypeParams(targetFunc *dst.FuncDecl) *dst.FieldList { + var trampolineTypeParams *dst.FieldList + if ast.HasReceiver(targetFunc) { + receiverTypeParams := extractReceiverTypeParams(targetFunc.Recv.List[0].Type) + if receiverTypeParams != nil { + trampolineTypeParams = receiverTypeParams + } + } + if targetFunc.Type.TypeParams != nil { + if trampolineTypeParams == nil { + trampolineTypeParams = targetFunc.Type.TypeParams + } else { + combined := &dst.FieldList{List: make([]*dst.Field, 0)} + combined.List = append(combined.List, trampolineTypeParams.List...) + combined.List = append(combined.List, targetFunc.Type.TypeParams.List...) + trampolineTypeParams = combined + } + } + return trampolineTypeParams +} + +// replaceGenericInstantiations replaces generic type instantiations with interface{} +// For example: *GenStruct[T] -> *interface{}, []GenStruct[T, U] -> []interface{} +func replaceGenericInstantiations(t dst.Expr) dst.Expr { + switch tType := t.(type) { + case *dst.StarExpr: + // *GenStruct[T] -> *interface{} + return ast.DereferenceOf(replaceGenericInstantiations(tType.X)) + case *dst.ArrayType: + // []GenStruct[T] -> []interface{} + return ast.ArrayType(replaceGenericInstantiations(tType.Elt)) + case *dst.MapType: + // map[K]GenStruct[T] -> map[interface{}]interface{} + return &dst.MapType{ + Key: replaceGenericInstantiations(tType.Key), + Value: replaceGenericInstantiations(tType.Value), + } + case *dst.IndexExpr: + // GenStruct[T] -> interface{} + return ast.InterfaceType() + case *dst.IndexListExpr: + // GenStruct[T, U] -> interface{} + return ast.InterfaceType() + } + return t +} + // desugarType desugars parameter type to its original type, if parameter // is type of ...T, it will be converted to []T func desugarType(param *dst.Field) dst.Expr { @@ -666,39 +772,121 @@ func (ip *InstrumentPhase) rewriteHookContext() { methodGetParamBody := findSwitchBlock(methodGetParam, 0) methodSetRetValBody := findSwitchBlock(methodSetRetVal, 1) methodGetRetValBody := findSwitchBlock(methodGetRetVal, 0) + + combinedTypeParams := combineTypeParams(ip.targetFunc) + + ip.rewriteHookContextParams(methodSetParamBody, methodGetParamBody, combinedTypeParams) + ip.rewriteHookContextResults(methodSetRetValBody, methodGetRetValBody, combinedTypeParams) +} + +func (ip *InstrumentPhase) rewriteHookContextParams( + methodSetParamBody, methodGetParamBody *dst.BlockStmt, + combinedTypeParams *dst.FieldList, +) { idx := 0 if ast.HasReceiver(ip.targetFunc) { - recvType := ip.targetFunc.Recv.List[0].Type + splitRecv := ast.SplitMultiNameFields(ip.targetFunc.Recv) + recvType := replaceGenericInstantiations(splitRecv.List[0].Type) + recvType = replaceTypeParamsWithAny(recvType, combinedTypeParams) clause := setParamClause(idx, recvType) methodSetParamBody.List = append(methodSetParamBody.List, clause) clause = getParamClause(idx, recvType) methodGetParamBody.List = append(methodGetParamBody.List, clause) idx++ } - for _, param := range ip.targetFunc.Type.Params.List { - paramType := desugarType(param) - for range param.Names { - clause := setParamClause(idx, paramType) - methodSetParamBody.List = append(methodSetParamBody.List, clause) - clause = getParamClause(idx, paramType) - methodGetParamBody.List = append(methodGetParamBody.List, clause) + splitParams := ast.SplitMultiNameFields(ip.targetFunc.Type.Params) + for _, param := range splitParams.List { + paramType := replaceTypeParamsWithAny(desugarType(param), combinedTypeParams) + clause := setParamClause(idx, paramType) + methodSetParamBody.List = append(methodSetParamBody.List, clause) + clause = getParamClause(idx, paramType) + methodGetParamBody.List = append(methodGetParamBody.List, clause) + idx++ + } +} + +func (ip *InstrumentPhase) rewriteHookContextResults( + methodSetRetValBody, methodGetRetValBody *dst.BlockStmt, + combinedTypeParams *dst.FieldList, +) { + if ip.targetFunc.Type.Results != nil { + idx := 0 + splitResults := ast.SplitMultiNameFields(ip.targetFunc.Type.Results) + for _, retval := range splitResults.List { + retType := replaceTypeParamsWithAny(desugarType(retval), combinedTypeParams) + clause := getReturnValClause(idx, retType) + methodGetRetValBody.List = append(methodGetRetValBody.List, clause) + clause = setReturnValClause(idx, retType) + methodSetRetValBody.List = append(methodSetRetValBody.List, clause) idx++ } } - // Rewrite GetReturnVal and SetReturnVal methods - if ip.targetFunc.Type.Results != nil { - idx = 0 - for _, retval := range ip.targetFunc.Type.Results.List { - retType := desugarType(retval) - for range retval.Names { - clause := getReturnValClause(idx, retType) - methodGetRetValBody.List = append(methodGetRetValBody.List, clause) - clause = setReturnValClause(idx, retType) - methodSetRetValBody.List = append(methodSetRetValBody.List, clause) - idx++ +} + +// isTypeParameter checks if a type expression is a bare type parameter identifier +func isTypeParameter(t dst.Expr, typeParams *dst.FieldList) bool { + if typeParams == nil { + return false + } + ident, ok := t.(*dst.Ident) + if !ok { + return false + } + // Check if this identifier matches any type parameter name + for _, field := range typeParams.List { + for _, name := range field.Names { + if name.Name == ident.Name { + return true } } } + return false +} + +// replaceTypeParamsWithAny replaces type parameters with interface{} for use in +// non-generic contexts like HookContextImpl methods +func replaceTypeParamsWithAny(t dst.Expr, typeParams *dst.FieldList) dst.Expr { + if isTypeParameter(t, typeParams) { + return ast.InterfaceType() + } + + // For complex types like *T, []T, map[K]V, etc., handle them recursively + switch tType := t.(type) { + case *dst.StarExpr: + // *T -> *interface{} + return ast.DereferenceOf(replaceTypeParamsWithAny(tType.X, typeParams)) + case *dst.ArrayType: + // []T -> []interface{} + return ast.ArrayType(replaceTypeParamsWithAny(tType.Elt, typeParams)) + case *dst.MapType: + // map[K]V -> map[interface{}]interface{} + return &dst.MapType{ + Key: replaceTypeParamsWithAny(tType.Key, typeParams), + Value: replaceTypeParamsWithAny(tType.Value, typeParams), + } + case *dst.ChanType: + // chan T, <-chan T, chan<- T -> chan interface{}, etc. + return &dst.ChanType{ + Dir: tType.Dir, + Value: replaceTypeParamsWithAny(tType.Value, typeParams), + } + case *dst.IndexExpr: + // GenStruct[T] -> interface{} (for generic receiver methods) + // The hook function expects interface{} for generic types + return ast.InterfaceType() + case *dst.IndexListExpr: + // GenStruct[T, U] -> interface{} (for generic receiver methods with multiple type params) + return ast.InterfaceType() + case *dst.Ident, *dst.SelectorExpr, *dst.InterfaceType: + // Base types without type parameters, return as-is + return t + default: + // Unsupported cases: + // - *dst.FuncType (function types with type parameters) + // - Other uncommon type expressions + util.ShouldNotReachHere() + return t + } } func (ip *InstrumentPhase) callHookFunc(t *rule.InstFuncRule, before bool) error {