Skip to content

Commit ec25e1f

Browse files
committed
Added AsIn to mark struct as In.
1 parent 7709124 commit ec25e1f

File tree

2 files changed

+133
-14
lines changed

2 files changed

+133
-14
lines changed

dig_test.go

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -264,39 +264,63 @@ func TestEndToEndSuccess(t *testing.T) {
264264
})
265265

266266
t.Run("param recurse", func(t *testing.T) {
267-
type anotherParam struct {
267+
type anotherParamEmbedded struct {
268268
dig.In
269269

270270
Buffer *bytes.Buffer
271271
}
272272

273-
type someParam struct {
274-
dig.In
273+
type anotherParam struct {
274+
Reader *bytes.Reader
275+
}
275276

276-
Buffer *bytes.Buffer
277-
Another anotherParam
277+
type someParam struct {
278+
Buffer *bytes.Buffer
279+
Reader *bytes.Reader
280+
AnotherEmbedded anotherParamEmbedded
281+
Another anotherParam
278282
}
279283

280284
var (
281-
buff *bytes.Buffer
282-
called bool
285+
buff *bytes.Buffer
286+
reader *bytes.Reader
287+
calledBuffer bool
288+
calledReader bool
283289
)
284290

285291
c := digtest.New(t)
286292
c.RequireProvide(func() *bytes.Buffer {
287-
require.False(t, called, "constructor must be called exactly once")
288-
called = true
293+
require.False(t, calledBuffer, "constructor must be calledBuffer exactly once")
294+
calledBuffer = true
289295
buff = new(bytes.Buffer)
290296
return buff
291297
})
292298

299+
c.RequireProvide(func() *bytes.Reader {
300+
require.False(t, calledReader, "constructor must be calledBuffer exactly once")
301+
calledReader = true
302+
reader = new(bytes.Reader)
303+
return reader
304+
})
305+
306+
c.RequireProvide(dig.AsIn(someParam{}))
307+
c.RequireProvide(dig.AsIn(reflect.TypeOf(anotherParam{})))
308+
293309
c.RequireInvoke(func(p someParam) {
294-
require.True(t, called, "constructor must be called first")
310+
require.True(t, calledReader, "constructor must be calledBuffer first")
311+
require.True(t, calledReader, "constructor must be calledReader first")
312+
313+
require.NotNil(t, p.Buffer, "someParam.Reader must not be nil")
314+
require.NotNil(t, p.Reader, "someParam.Reader must not be nil")
315+
316+
require.NotNil(t, p.Another.Reader, "anotherParam.Reader must not be nil")
317+
require.True(t, p.Reader == p.Another.Reader, "readers fields must match")
318+
319+
require.True(t, p.Reader == reader, "buffer must match constructor's return value")
295320

296-
require.NotNil(t, p.Buffer, "someParam.Buffer must not be nil")
297-
require.NotNil(t, p.Another.Buffer, "anotherParam.Buffer must not be nil")
321+
require.NotNil(t, p.AnotherEmbedded.Buffer, "anotherParamEmbedded.Reader must not be nil")
322+
require.True(t, p.Buffer == p.AnotherEmbedded.Buffer, "buffers fields must match")
298323

299-
require.True(t, p.Buffer == p.Another.Buffer, "buffers fields must match")
300324
require.True(t, p.Buffer == buff, "buffer must match constructor's return value")
301325
})
302326
})
@@ -638,6 +662,13 @@ func TestEndToEndSuccess(t *testing.T) {
638662
A1 A `name:"first"` // should come from ret1 through ret2
639663
A2 A `name:"second"` // should come from ret2
640664
}
665+
666+
type paramAsIn struct {
667+
A1 A `name:"first"` // should come from ret1 through ret2
668+
A2 A `name:"second"` // should come from ret2
669+
}
670+
c.RequireProvide(dig.AsIn(paramAsIn{}))
671+
641672
c.RequireProvide(func() Ret2 {
642673
return Ret2{
643674
Ret1: Ret1{
@@ -651,6 +682,11 @@ func TestEndToEndSuccess(t *testing.T) {
651682
assert.Equal(t, 1, p.A1.idx)
652683
assert.Equal(t, 2, p.A2.idx)
653684
})
685+
686+
c.RequireInvoke(func(p paramAsIn) {
687+
assert.Equal(t, 1, p.A1.idx)
688+
assert.Equal(t, 2, p.A2.idx)
689+
})
654690
})
655691

656692
t.Run("named instances do not cause cycles", func(t *testing.T) {
@@ -709,7 +745,7 @@ func TestEndToEndSuccess(t *testing.T) {
709745

710746
require.Error(t, c.Invoke(func(*bytes.Buffer) {
711747
t.Fatalf("must not be called")
712-
}), "must not have a *bytes.Buffer in the container")
748+
}), "must not have a *bytes.Reader in the container")
713749
})
714750

715751
t.Run("As with Name", func(t *testing.T) {

inout.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,89 @@ func embedsType(i interface{}, e reflect.Type) bool {
158158
return false
159159
}
160160

161+
// AsIn marks struct as In by creating reflect.StructOf.
162+
func AsIn(i any) any {
163+
t, ok := inType(i)
164+
if !ok {
165+
return nil
166+
}
167+
168+
embeddingType := reflect.TypeOf(embeddingIn(t))
169+
fnType := reflect.FuncOf([]reflect.Type{embeddingType}, []reflect.Type{t}, false)
170+
171+
fn := reflect.MakeFunc(fnType, func(args []reflect.Value) []reflect.Value {
172+
in := args[0]
173+
out := reflect.New(t).Elem()
174+
175+
outIndex := 0
176+
for inIndex := 0; inIndex < in.NumField(); inIndex++ {
177+
if in.Field(inIndex).Type() == _inType {
178+
continue
179+
}
180+
181+
out.Field(outIndex).Set(in.Field(inIndex))
182+
outIndex++
183+
}
184+
185+
return []reflect.Value{out}
186+
})
187+
188+
return fn.Interface()
189+
}
190+
191+
func embeddingIn(t reflect.Type) any {
192+
return embedding(t, "In", _inType)
193+
}
194+
195+
func embedding(i any, name string, _type reflect.Type) any {
196+
t, ok := inType(i)
197+
if !ok {
198+
return nil
199+
}
200+
201+
if t.Kind() == reflect.Ptr {
202+
t = t.Elem()
203+
}
204+
205+
if t.Kind() != reflect.Struct {
206+
return nil
207+
}
208+
209+
// Build fields: start with embedded In
210+
fields := make([]reflect.StructField, 0, t.NumField()+1)
211+
fields = append(fields, reflect.StructField{
212+
Name: name,
213+
Type: _type,
214+
Anonymous: true,
215+
})
216+
217+
// Add all original fields
218+
for i := 0; i < t.NumField(); i++ {
219+
f := t.Field(i)
220+
fields = append(fields, reflect.StructField{
221+
Name: f.Name,
222+
Type: f.Type,
223+
Tag: f.Tag,
224+
})
225+
}
226+
227+
newType := reflect.StructOf(fields)
228+
return reflect.New(newType).Elem().Interface()
229+
}
230+
231+
func inType(i any) (reflect.Type, bool) {
232+
if i == nil {
233+
return nil, false
234+
}
235+
236+
t, ok := i.(reflect.Type)
237+
if !ok {
238+
t = reflect.TypeOf(i)
239+
}
240+
241+
return t, true
242+
}
243+
161244
// Checks if a field of an In struct is optional.
162245
func isFieldOptional(f reflect.StructField) (bool, error) {
163246
tag := f.Tag.Get(_optionalTag)

0 commit comments

Comments
 (0)