Skip to content

Commit 63a9420

Browse files
committed
Support recv generic funcs
1 parent 05bac68 commit 63a9420

File tree

6 files changed

+215
-25
lines changed

6 files changed

+215
-25
lines changed

demo/basic/main.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ type MyStruct struct{}
3131

3232
func (m *MyStruct) Example() { println("MyStruct.Example") }
3333

34+
type GenStruct[T any] struct {
35+
Value T
36+
}
37+
38+
func (m *GenStruct[T]) GenericRecvExample(t T) T {
39+
fmt.Printf("%s%s\n", m.Value, t)
40+
return t
41+
}
42+
3443
func GenericExample[K comparable, V any](key K, value V) V {
3544
println("Hello, Generic World!", key, value)
3645
return value
@@ -59,11 +68,14 @@ func main() {
5968
// Call the Example function to trigger the instrumentation
6069
Example()
6170
m := &MyStruct{}
62-
GenericExample(1, 2)
6371
// Add a new field to the struct
6472
m.NewField = "abc"
6573
m.Example()
6674

75+
GenericExample(1, 2)
76+
g := &GenStruct[string]{Value: "Hello"}
77+
_ = g.GenericRecvExample(", Generic Recv World!")
78+
6779
// Call real module function
6880
println(rate.Every(time.Duration(1)))
6981
}

pkg/instrumentation/helloworld/helloworld_hook.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ func MyHook1After(ictx inst.HookContext) {
7171
println("After MyStruct.Example()")
7272
}
7373

74+
func MyHookRecvBefore(ictx inst.HookContext, recv, _ interface{}) {
75+
println("GenericRecvExample before hook")
76+
}
77+
78+
func MyHookRecvAfter(ictx inst.HookContext, _ interface{}) {
79+
println("GenericRecvExample after hook")
80+
}
81+
7482
func MyHookGenericBefore(ictx inst.HookContext, _, _ interface{}) {
7583
println("GenericExample before hook")
7684
}

test/integration/basic_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ func TestBasic(t *testing.T) {
2727
"Hello, Generic World! 1 2",
2828
"GenericExample after hook",
2929
"traceID: 123, spanID: 456",
30+
"GenericRecvExample before hook",
31+
"Hello, Generic Recv World!",
32+
"GenericRecvExample after hook",
33+
"traceID: 123, spanID: 456",
3034
"[MyHook]",
3135
"=setupOpenTelemetry=",
3236
"RawCode",

tool/data/helloworld.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,11 @@ hook_generic:
6464
before: MyHookGenericBefore
6565
after: MyHookGenericAfter
6666
path: "github.com/open-telemetry/opentelemetry-go-compile-instrumentation/pkg/instrumentation/helloworld"
67+
68+
hook_generic_recv:
69+
target: main
70+
func: GenericRecvExample
71+
recv: "*GenStruct"
72+
before: MyHookRecvBefore
73+
after: MyHookRecvAfter
74+
path: "github.com/open-telemetry/opentelemetry-go-compile-instrumentation/pkg/instrumentation/helloworld"

tool/internal/ast/shared.go

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,48 @@ func FindFuncDeclWithoutRecv(root *dst.File, funcName string) *dst.FuncDecl {
4444
return decls[0]
4545
}
4646

47+
// extractBaseReceiverType extracts the base type name from a receiver expression,
48+
// handling both generic and non-generic types.
49+
// For example:
50+
// - *MyStruct -> *MyStruct
51+
// - MyStruct -> MyStruct
52+
// - *GenStruct[T] -> *GenStruct
53+
// - GenStruct[T] -> GenStruct
54+
func extractBaseReceiverType(recvTypeExpr dst.Expr) string {
55+
switch expr := recvTypeExpr.(type) {
56+
case *dst.StarExpr: // func (*Recv)T or func (*Recv[T])T
57+
// Check if X is an Ident (non-generic) or IndexExpr/IndexListExpr (generic)
58+
switch x := expr.X.(type) {
59+
case *dst.Ident:
60+
// Non-generic pointer receiver: *MyStruct
61+
return "*" + x.Name
62+
case *dst.IndexExpr:
63+
// Generic pointer receiver with single type param: *GenStruct[T]
64+
if baseIdent, ok := x.X.(*dst.Ident); ok {
65+
return "*" + baseIdent.Name
66+
}
67+
case *dst.IndexListExpr:
68+
// Generic pointer receiver with multiple type params: *GenStruct[T, U]
69+
if baseIdent, ok := x.X.(*dst.Ident); ok {
70+
return "*" + baseIdent.Name
71+
}
72+
}
73+
case *dst.Ident: // func (Recv)T
74+
return expr.Name
75+
case *dst.IndexExpr:
76+
// Generic value receiver with single type param: GenStruct[T]
77+
if baseIdent, ok := expr.X.(*dst.Ident); ok {
78+
return baseIdent.Name
79+
}
80+
case *dst.IndexListExpr:
81+
// Generic value receiver with multiple type params: GenStruct[T, U]
82+
if baseIdent, ok := expr.X.(*dst.Ident); ok {
83+
return baseIdent.Name
84+
}
85+
}
86+
return ""
87+
}
88+
4789
func FindFuncDecl(root *dst.File, funcName, recv string) *dst.FuncDecl {
4890
decls := findFuncDecls(root, func(funcDecl *dst.FuncDecl) bool {
4991
// Receiver type is ignored, match func name only
@@ -59,26 +101,15 @@ func FindFuncDecl(root *dst.File, funcName, recv string) *dst.FuncDecl {
59101

60102
// Receiver type is specified, and target function has receiver
61103
// Match both func name and receiver type
62-
switch recvTypeExpr := funcDecl.Recv.List[0].Type.(type) {
63-
case *dst.StarExpr: // func (*Recv)T
64-
tn, ok := recvTypeExpr.X.(*dst.Ident)
65-
if !ok {
66-
// This is a generic type, we don't support it yet
67-
return false
68-
}
69-
t := "*" + tn.Name
70-
return t == recv && name == funcName
71-
case *dst.Ident: // func (Recv)T
72-
t := recvTypeExpr.Name
73-
return t == recv && name == funcName
74-
case *dst.IndexExpr:
75-
// This is a generic type, we don't support it yet
76-
return false
77-
default:
104+
recvTypeExpr := funcDecl.Recv.List[0].Type
105+
baseType := extractBaseReceiverType(recvTypeExpr)
106+
107+
if baseType == "" {
78108
msg := fmt.Sprintf("unexpected receiver type: %T", recvTypeExpr)
79109
util.Unimplemented(msg)
80110
}
81-
return false
111+
112+
return baseType == recv && name == funcName
82113
})
83114

84115
if len(decls) == 0 {

tool/internal/instrument/trampoline.go

Lines changed: 134 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,9 @@ func (ip *InstrumentPhase) buildTrampolineTypes() {
473473
}
474474
}
475475
addHookContext(afterHookFunc.Type.Params)
476-
beforeHookFunc.Type.TypeParams = ast.CloneTypeParams(ip.targetFunc.Type.TypeParams)
477-
afterHookFunc.Type.TypeParams = ast.CloneTypeParams(ip.targetFunc.Type.TypeParams)
476+
trampolineTypeParams := combineTypeParams(ip.targetFunc)
477+
beforeHookFunc.Type.TypeParams = ast.CloneTypeParams(trampolineTypeParams)
478+
afterHookFunc.Type.TypeParams = ast.CloneTypeParams(trampolineTypeParams)
478479
}
479480

480481
func assignString(assignStmt *dst.AssignStmt, val string) bool {
@@ -640,6 +641,108 @@ func setReturnValClause(idx int, t dst.Expr) *dst.CaseClause {
640641
return setValue(trampolineReturnValsIdentifier, idx, t)
641642
}
642643

644+
// extractReceiverTypeParams extracts type parameters from a receiver type expression
645+
// For example: *GenStruct[T] or GenStruct[T, U] -> FieldList with T and U as type parameters
646+
func extractReceiverTypeParams(recvType dst.Expr) *dst.FieldList {
647+
switch t := recvType.(type) {
648+
case *dst.StarExpr:
649+
// *GenStruct[T] - recurse into X
650+
return extractReceiverTypeParams(t.X)
651+
case *dst.IndexExpr:
652+
// GenStruct[T] - single type parameter
653+
if ident, ok := t.Index.(*dst.Ident); ok {
654+
return &dst.FieldList{
655+
List: []*dst.Field{{
656+
Names: []*dst.Ident{ident},
657+
Type: ast.Ident("any"), // Type constraint for the parameter
658+
}},
659+
}
660+
}
661+
case *dst.IndexListExpr:
662+
// GenStruct[T, U, ...] - multiple type parameters
663+
fields := make([]*dst.Field, 0, len(t.Indices))
664+
for _, idx := range t.Indices {
665+
if ident, ok := idx.(*dst.Ident); ok {
666+
fields = append(fields, &dst.Field{
667+
Names: []*dst.Ident{ident},
668+
Type: ast.Ident("any"), // Type constraint for the parameter
669+
})
670+
}
671+
}
672+
if len(fields) > 0 {
673+
return &dst.FieldList{List: fields}
674+
}
675+
}
676+
return nil
677+
}
678+
679+
// combineTypeParams combines type parameters from the receiver and function type parameters.
680+
// For methods on generic types, it extracts type parameters from the receiver and merges
681+
// them with the function's type parameters.
682+
// Receiver type parameters come first, followed by function type parameters.
683+
//
684+
// Example:
685+
//
686+
// Original: func (c *Container[K]) Transform[V any]() V
687+
// Result: [K, V]
688+
//
689+
// Generated trampolines:
690+
// func OtelBeforeTrampoline_Container_Transform[K comparable, V any](
691+
// hookContext *HookContext,
692+
// recv0 *Container[K], // ← Uses K
693+
// ) { ... }
694+
//
695+
// func OtelAfterTrampoline_Container_Transform[K comparable, V any](
696+
// hookContext *HookContext,
697+
// arg0 *V, // ← Uses V (return type)
698+
// ) { ... }
699+
func combineTypeParams(targetFunc *dst.FuncDecl) *dst.FieldList {
700+
var trampolineTypeParams *dst.FieldList
701+
if ast.HasReceiver(targetFunc) {
702+
receiverTypeParams := extractReceiverTypeParams(targetFunc.Recv.List[0].Type)
703+
if receiverTypeParams != nil {
704+
trampolineTypeParams = receiverTypeParams
705+
}
706+
}
707+
if targetFunc.Type.TypeParams != nil {
708+
if trampolineTypeParams == nil {
709+
trampolineTypeParams = targetFunc.Type.TypeParams
710+
} else {
711+
combined := &dst.FieldList{List: make([]*dst.Field, 0)}
712+
combined.List = append(combined.List, trampolineTypeParams.List...)
713+
combined.List = append(combined.List, targetFunc.Type.TypeParams.List...)
714+
trampolineTypeParams = combined
715+
}
716+
}
717+
return trampolineTypeParams
718+
}
719+
720+
// replaceGenericInstantiations replaces generic type instantiations with interface{}
721+
// For example: *GenStruct[T] -> *interface{}, []GenStruct[T, U] -> []interface{}
722+
func replaceGenericInstantiations(t dst.Expr) dst.Expr {
723+
switch tType := t.(type) {
724+
case *dst.StarExpr:
725+
// *GenStruct[T] -> *interface{}
726+
return ast.DereferenceOf(replaceGenericInstantiations(tType.X))
727+
case *dst.ArrayType:
728+
// []GenStruct[T] -> []interface{}
729+
return ast.ArrayType(replaceGenericInstantiations(tType.Elt))
730+
case *dst.MapType:
731+
// map[K]GenStruct[T] -> map[interface{}]interface{}
732+
return &dst.MapType{
733+
Key: replaceGenericInstantiations(tType.Key),
734+
Value: replaceGenericInstantiations(tType.Value),
735+
}
736+
case *dst.IndexExpr:
737+
// GenStruct[T] -> interface{}
738+
return ast.InterfaceType()
739+
case *dst.IndexListExpr:
740+
// GenStruct[T, U] -> interface{}
741+
return ast.InterfaceType()
742+
}
743+
return t
744+
}
745+
643746
// desugarType desugars parameter type to its original type, if parameter
644747
// is type of ...T, it will be converted to []T
645748
func desugarType(param *dst.Field) dst.Expr {
@@ -677,10 +780,22 @@ func (ip *InstrumentPhase) rewriteHookContext() {
677780
methodGetParamBody := findSwitchBlock(methodGetParam, 0)
678781
methodSetRetValBody := findSwitchBlock(methodSetRetVal, 1)
679782
methodGetRetValBody := findSwitchBlock(methodGetRetVal, 0)
783+
784+
combinedTypeParams := combineTypeParams(ip.targetFunc)
785+
786+
ip.rewriteHookContextParams(methodSetParamBody, methodGetParamBody, combinedTypeParams)
787+
ip.rewriteHookContextResults(methodSetRetValBody, methodGetRetValBody, combinedTypeParams)
788+
}
789+
790+
func (ip *InstrumentPhase) rewriteHookContextParams(
791+
methodSetParamBody, methodGetParamBody *dst.BlockStmt,
792+
combinedTypeParams *dst.FieldList,
793+
) {
680794
idx := 0
681795
if ast.HasReceiver(ip.targetFunc) {
682796
splitRecv := ast.SplitMultiNameFields(ip.targetFunc.Recv)
683-
recvType := replaceTypeParamsWithAny(splitRecv.List[0].Type, ip.targetFunc.Type.TypeParams)
797+
recvType := replaceGenericInstantiations(splitRecv.List[0].Type)
798+
recvType = replaceTypeParamsWithAny(recvType, combinedTypeParams)
684799
clause := setParamClause(idx, recvType)
685800
methodSetParamBody.List = append(methodSetParamBody.List, clause)
686801
clause = getParamClause(idx, recvType)
@@ -689,19 +804,24 @@ func (ip *InstrumentPhase) rewriteHookContext() {
689804
}
690805
splitParams := ast.SplitMultiNameFields(ip.targetFunc.Type.Params)
691806
for _, param := range splitParams.List {
692-
paramType := replaceTypeParamsWithAny(desugarType(param), ip.targetFunc.Type.TypeParams)
807+
paramType := replaceTypeParamsWithAny(desugarType(param), combinedTypeParams)
693808
clause := setParamClause(idx, paramType)
694809
methodSetParamBody.List = append(methodSetParamBody.List, clause)
695810
clause = getParamClause(idx, paramType)
696811
methodGetParamBody.List = append(methodGetParamBody.List, clause)
697812
idx++
698813
}
699-
// Rewrite GetReturnVal and SetReturnVal methods
814+
}
815+
816+
func (ip *InstrumentPhase) rewriteHookContextResults(
817+
methodSetRetValBody, methodGetRetValBody *dst.BlockStmt,
818+
combinedTypeParams *dst.FieldList,
819+
) {
700820
if ip.targetFunc.Type.Results != nil {
701-
idx = 0
821+
idx := 0
702822
splitResults := ast.SplitMultiNameFields(ip.targetFunc.Type.Results)
703823
for _, retval := range splitResults.List {
704-
retType := replaceTypeParamsWithAny(desugarType(retval), ip.targetFunc.Type.TypeParams)
824+
retType := replaceTypeParamsWithAny(desugarType(retval), combinedTypeParams)
705825
clause := getReturnValClause(idx, retType)
706826
methodGetRetValBody.List = append(methodGetRetValBody.List, clause)
707827
clause = setReturnValClause(idx, retType)
@@ -752,6 +872,13 @@ func replaceTypeParamsWithAny(t dst.Expr, typeParams *dst.FieldList) dst.Expr {
752872
Key: replaceTypeParamsWithAny(tType.Key, typeParams),
753873
Value: replaceTypeParamsWithAny(tType.Value, typeParams),
754874
}
875+
case *dst.IndexExpr:
876+
// GenStruct[T] -> interface{} (for generic receiver methods)
877+
// The hook function expects interface{} for generic types
878+
return ast.InterfaceType()
879+
case *dst.IndexListExpr:
880+
// GenStruct[T, U] -> interface{} (for generic receiver methods with multiple type params)
881+
return ast.InterfaceType()
755882
}
756883
return t
757884
}

0 commit comments

Comments
 (0)