@@ -50,96 +50,75 @@ func assertIndexListExprCall(
5050 assert .Len (t , expr .Args , expectedArgCount )
5151}
5252
53+ // Helper function to parse a complete function and extract its type parameters
54+ func parseFuncTypeParams (t * testing.T , funcSource string ) * dst.FieldList {
55+ parser := NewAstParser ()
56+ file , err := parser .ParseSource ("package main\n " + funcSource )
57+ require .NoError (t , err )
58+ require .Len (t , file .Decls , 1 )
59+ funcDecl , ok := file .Decls [0 ].(* dst.FuncDecl )
60+ require .True (t , ok )
61+ return funcDecl .Type .TypeParams
62+ }
63+
5364func TestCallToGeneric (t * testing.T ) {
5465 tests := []struct {
5566 name string
5667 funcName string
57- typeParams * dst. FieldList
68+ funcSource string // Source code for parsing type params
5869 args []dst.Expr
5970 validate func (* testing.T , * dst.CallExpr )
6071 }{
6172 {
6273 name : "nil type params returns simple call" ,
6374 funcName : "Foo" ,
64- typeParams : nil ,
75+ funcSource : "func Foo(x, y int) {}" , // No type params
6576 args : []dst.Expr {Ident ("x" ), Ident ("y" )},
6677 validate : func (t * testing.T , expr * dst.CallExpr ) {
6778 assertSimpleCall (t , expr , "Foo" , 2 )
6879 },
6980 },
7081 {
71- name : "single type parameter creates IndexExpr" ,
72- funcName : "GenericFunc" ,
73- typeParams : & dst.FieldList {
74- List : []* dst.Field {
75- {
76- Names : []* dst.Ident {Ident ("T" )},
77- Type : Ident ("any" ),
78- },
79- },
80- },
81- args : []dst.Expr {Ident ("value" )},
82+ name : "single type parameter creates IndexExpr" ,
83+ funcName : "GenericFunc" ,
84+ funcSource : "func GenericFunc[T any](value T) {}" ,
85+ args : []dst.Expr {Ident ("value" )},
8286 validate : func (t * testing.T , expr * dst.CallExpr ) {
8387 assertIndexExprCall (t , expr , "GenericFunc" , "T" , 1 )
8488 },
8589 },
8690 {
87- name : "multiple type parameters creates IndexListExpr" ,
88- funcName : "MultiGeneric" ,
89- typeParams : & dst.FieldList {
90- List : []* dst.Field {
91- {
92- Names : []* dst.Ident {Ident ("T" )},
93- Type : Ident ("any" ),
94- },
95- {
96- Names : []* dst.Ident {Ident ("U" )},
97- Type : Ident ("comparable" ),
98- },
99- },
100- },
101- args : []dst.Expr {Ident ("x" ), Ident ("y" )},
91+ name : "multiple type parameters creates IndexListExpr" ,
92+ funcName : "MultiGeneric" ,
93+ funcSource : "func MultiGeneric[T any, U comparable](x T, y U) {}" ,
94+ args : []dst.Expr {Ident ("x" ), Ident ("y" )},
10295 validate : func (t * testing.T , expr * dst.CallExpr ) {
10396 assertIndexListExprCall (t , expr , "MultiGeneric" , []string {"T" , "U" }, 2 )
10497 },
10598 },
10699 {
107- name : "field with multiple names creates multiple indices" ,
108- funcName : "MultiNameGeneric" ,
109- typeParams : & dst.FieldList {
110- List : []* dst.Field {
111- {
112- Names : []* dst.Ident {Ident ("T" ), Ident ("U" )},
113- Type : Ident ("any" ),
114- },
115- },
116- },
117- args : []dst.Expr {Ident ("value" )},
100+ name : "field with multiple names creates multiple indices" ,
101+ funcName : "MultiNameGeneric" ,
102+ funcSource : "func MultiNameGeneric[T, U any](value T) {}" ,
103+ args : []dst.Expr {Ident ("value" )},
118104 validate : func (t * testing.T , expr * dst.CallExpr ) {
119105 assertIndexListExprCall (t , expr , "MultiNameGeneric" , []string {"T" , "U" }, 1 )
120106 },
121107 },
122108 {
123- name : "no arguments with type parameters" ,
124- funcName : "NoArgsGeneric" ,
125- typeParams : & dst.FieldList {
126- List : []* dst.Field {
127- {
128- Names : []* dst.Ident {Ident ("T" )},
129- Type : Ident ("any" ),
130- },
131- },
132- },
133- args : []dst.Expr {},
109+ name : "no arguments with type parameters" ,
110+ funcName : "NoArgsGeneric" ,
111+ funcSource : "func NoArgsGeneric[T any]() {}" ,
112+ args : []dst.Expr {},
134113 validate : func (t * testing.T , expr * dst.CallExpr ) {
135114 assertIndexExprCall (t , expr , "NoArgsGeneric" , "T" , 0 )
136115 },
137116 },
138117 }
139-
140118 for _ , tt := range tests {
141119 t .Run (tt .name , func (t * testing.T ) {
142- result := CallToGeneric (tt .funcName , tt .typeParams , tt .args )
120+ typeParams := parseFuncTypeParams (t , tt .funcSource )
121+ result := CallToGeneric (tt .funcName , typeParams , tt .args )
143122 require .NotNil (t , result )
144123 tt .validate (t , result )
145124 })
0 commit comments