@@ -473,8 +473,9 @@ func (ip *InstrumentPhase) buildTrampolineTypes() {
473473 }
474474 }
475475 addHookContext (afterHookFunc .Type .Params )
476- beforeHookFunc .Type .TypeParams = ast .CloneTypeParams (ip .targetFunc .Type .TypeParams )
477- afterHookFunc .Type .TypeParams = ast .CloneTypeParams (ip .targetFunc .Type .TypeParams )
476+ trampolineTypeParams := combineTypeParams (ip .targetFunc )
477+ beforeHookFunc .Type .TypeParams = ast .CloneTypeParams (trampolineTypeParams )
478+ afterHookFunc .Type .TypeParams = ast .CloneTypeParams (trampolineTypeParams )
478479}
479480
480481func assignString (assignStmt * dst.AssignStmt , val string ) bool {
@@ -640,6 +641,108 @@ func setReturnValClause(idx int, t dst.Expr) *dst.CaseClause {
640641 return setValue (trampolineReturnValsIdentifier , idx , t )
641642}
642643
644+ // extractReceiverTypeParams extracts type parameters from a receiver type expression
645+ // For example: *GenStruct[T] or GenStruct[T, U] -> FieldList with T and U as type parameters
646+ func extractReceiverTypeParams (recvType dst.Expr ) * dst.FieldList {
647+ switch t := recvType .(type ) {
648+ case * dst.StarExpr :
649+ // *GenStruct[T] - recurse into X
650+ return extractReceiverTypeParams (t .X )
651+ case * dst.IndexExpr :
652+ // GenStruct[T] - single type parameter
653+ if ident , ok := t .Index .(* dst.Ident ); ok {
654+ return & dst.FieldList {
655+ List : []* dst.Field {{
656+ Names : []* dst.Ident {ident },
657+ Type : ast .Ident ("any" ), // Type constraint for the parameter
658+ }},
659+ }
660+ }
661+ case * dst.IndexListExpr :
662+ // GenStruct[T, U, ...] - multiple type parameters
663+ fields := make ([]* dst.Field , 0 , len (t .Indices ))
664+ for _ , idx := range t .Indices {
665+ if ident , ok := idx .(* dst.Ident ); ok {
666+ fields = append (fields , & dst.Field {
667+ Names : []* dst.Ident {ident },
668+ Type : ast .Ident ("any" ), // Type constraint for the parameter
669+ })
670+ }
671+ }
672+ if len (fields ) > 0 {
673+ return & dst.FieldList {List : fields }
674+ }
675+ }
676+ return nil
677+ }
678+
679+ // combineTypeParams combines type parameters from the receiver and function type parameters.
680+ // For methods on generic types, it extracts type parameters from the receiver and merges
681+ // them with the function's type parameters.
682+ // Receiver type parameters come first, followed by function type parameters.
683+ //
684+ // Example:
685+ //
686+ // Original: func (c *Container[K]) Transform[V any]() V
687+ // Result: [K, V]
688+ //
689+ // Generated trampolines:
690+ // func OtelBeforeTrampoline_Container_Transform[K comparable, V any](
691+ // hookContext *HookContext,
692+ // recv0 *Container[K], // ← Uses K
693+ // ) { ... }
694+ //
695+ // func OtelAfterTrampoline_Container_Transform[K comparable, V any](
696+ // hookContext *HookContext,
697+ // arg0 *V, // ← Uses V (return type)
698+ // ) { ... }
699+ func combineTypeParams (targetFunc * dst.FuncDecl ) * dst.FieldList {
700+ var trampolineTypeParams * dst.FieldList
701+ if ast .HasReceiver (targetFunc ) {
702+ receiverTypeParams := extractReceiverTypeParams (targetFunc .Recv .List [0 ].Type )
703+ if receiverTypeParams != nil {
704+ trampolineTypeParams = receiverTypeParams
705+ }
706+ }
707+ if targetFunc .Type .TypeParams != nil {
708+ if trampolineTypeParams == nil {
709+ trampolineTypeParams = targetFunc .Type .TypeParams
710+ } else {
711+ combined := & dst.FieldList {List : make ([]* dst.Field , 0 )}
712+ combined .List = append (combined .List , trampolineTypeParams .List ... )
713+ combined .List = append (combined .List , targetFunc .Type .TypeParams .List ... )
714+ trampolineTypeParams = combined
715+ }
716+ }
717+ return trampolineTypeParams
718+ }
719+
720+ // replaceGenericInstantiations replaces generic type instantiations with interface{}
721+ // For example: *GenStruct[T] -> *interface{}, []GenStruct[T, U] -> []interface{}
722+ func replaceGenericInstantiations (t dst.Expr ) dst.Expr {
723+ switch tType := t .(type ) {
724+ case * dst.StarExpr :
725+ // *GenStruct[T] -> *interface{}
726+ return ast .DereferenceOf (replaceGenericInstantiations (tType .X ))
727+ case * dst.ArrayType :
728+ // []GenStruct[T] -> []interface{}
729+ return ast .ArrayType (replaceGenericInstantiations (tType .Elt ))
730+ case * dst.MapType :
731+ // map[K]GenStruct[T] -> map[interface{}]interface{}
732+ return & dst.MapType {
733+ Key : replaceGenericInstantiations (tType .Key ),
734+ Value : replaceGenericInstantiations (tType .Value ),
735+ }
736+ case * dst.IndexExpr :
737+ // GenStruct[T] -> interface{}
738+ return ast .InterfaceType ()
739+ case * dst.IndexListExpr :
740+ // GenStruct[T, U] -> interface{}
741+ return ast .InterfaceType ()
742+ }
743+ return t
744+ }
745+
643746// desugarType desugars parameter type to its original type, if parameter
644747// is type of ...T, it will be converted to []T
645748func desugarType (param * dst.Field ) dst.Expr {
@@ -677,10 +780,22 @@ func (ip *InstrumentPhase) rewriteHookContext() {
677780 methodGetParamBody := findSwitchBlock (methodGetParam , 0 )
678781 methodSetRetValBody := findSwitchBlock (methodSetRetVal , 1 )
679782 methodGetRetValBody := findSwitchBlock (methodGetRetVal , 0 )
783+
784+ combinedTypeParams := combineTypeParams (ip .targetFunc )
785+
786+ ip .rewriteHookContextParams (methodSetParamBody , methodGetParamBody , combinedTypeParams )
787+ ip .rewriteHookContextResults (methodSetRetValBody , methodGetRetValBody , combinedTypeParams )
788+ }
789+
790+ func (ip * InstrumentPhase ) rewriteHookContextParams (
791+ methodSetParamBody , methodGetParamBody * dst.BlockStmt ,
792+ combinedTypeParams * dst.FieldList ,
793+ ) {
680794 idx := 0
681795 if ast .HasReceiver (ip .targetFunc ) {
682796 splitRecv := ast .SplitMultiNameFields (ip .targetFunc .Recv )
683- recvType := replaceTypeParamsWithAny (splitRecv .List [0 ].Type , ip .targetFunc .Type .TypeParams )
797+ recvType := replaceGenericInstantiations (splitRecv .List [0 ].Type )
798+ recvType = replaceTypeParamsWithAny (recvType , combinedTypeParams )
684799 clause := setParamClause (idx , recvType )
685800 methodSetParamBody .List = append (methodSetParamBody .List , clause )
686801 clause = getParamClause (idx , recvType )
@@ -689,19 +804,24 @@ func (ip *InstrumentPhase) rewriteHookContext() {
689804 }
690805 splitParams := ast .SplitMultiNameFields (ip .targetFunc .Type .Params )
691806 for _ , param := range splitParams .List {
692- paramType := replaceTypeParamsWithAny (desugarType (param ), ip . targetFunc . Type . TypeParams )
807+ paramType := replaceTypeParamsWithAny (desugarType (param ), combinedTypeParams )
693808 clause := setParamClause (idx , paramType )
694809 methodSetParamBody .List = append (methodSetParamBody .List , clause )
695810 clause = getParamClause (idx , paramType )
696811 methodGetParamBody .List = append (methodGetParamBody .List , clause )
697812 idx ++
698813 }
699- // Rewrite GetReturnVal and SetReturnVal methods
814+ }
815+
816+ func (ip * InstrumentPhase ) rewriteHookContextResults (
817+ methodSetRetValBody , methodGetRetValBody * dst.BlockStmt ,
818+ combinedTypeParams * dst.FieldList ,
819+ ) {
700820 if ip .targetFunc .Type .Results != nil {
701- idx = 0
821+ idx : = 0
702822 splitResults := ast .SplitMultiNameFields (ip .targetFunc .Type .Results )
703823 for _ , retval := range splitResults .List {
704- retType := replaceTypeParamsWithAny (desugarType (retval ), ip . targetFunc . Type . TypeParams )
824+ retType := replaceTypeParamsWithAny (desugarType (retval ), combinedTypeParams )
705825 clause := getReturnValClause (idx , retType )
706826 methodGetRetValBody .List = append (methodGetRetValBody .List , clause )
707827 clause = setReturnValClause (idx , retType )
@@ -752,6 +872,13 @@ func replaceTypeParamsWithAny(t dst.Expr, typeParams *dst.FieldList) dst.Expr {
752872 Key : replaceTypeParamsWithAny (tType .Key , typeParams ),
753873 Value : replaceTypeParamsWithAny (tType .Value , typeParams ),
754874 }
875+ case * dst.IndexExpr :
876+ // GenStruct[T] -> interface{} (for generic receiver methods)
877+ // The hook function expects interface{} for generic types
878+ return ast .InterfaceType ()
879+ case * dst.IndexListExpr :
880+ // GenStruct[T, U] -> interface{} (for generic receiver methods with multiple type params)
881+ return ast .InterfaceType ()
755882 }
756883 return t
757884}
0 commit comments