Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 50 additions & 14 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,39 +264,63 @@ func TestEndToEndSuccess(t *testing.T) {
})

t.Run("param recurse", func(t *testing.T) {
type anotherParam struct {
type anotherParamEmbedded struct {
dig.In

Buffer *bytes.Buffer
}

type someParam struct {
dig.In
type anotherParam struct {
Reader *bytes.Reader
}

Buffer *bytes.Buffer
Another anotherParam
type someParam struct {
Buffer *bytes.Buffer
Reader *bytes.Reader
AnotherEmbedded anotherParamEmbedded
Another anotherParam
}

var (
buff *bytes.Buffer
called bool
buff *bytes.Buffer
reader *bytes.Reader
calledBuffer bool
calledReader bool
)

c := digtest.New(t)
c.RequireProvide(func() *bytes.Buffer {
require.False(t, called, "constructor must be called exactly once")
called = true
require.False(t, calledBuffer, "constructor must be calledBuffer exactly once")
calledBuffer = true
buff = new(bytes.Buffer)
return buff
})

c.RequireProvide(func() *bytes.Reader {
require.False(t, calledReader, "constructor must be calledBuffer exactly once")
calledReader = true
reader = new(bytes.Reader)
return reader
})

c.RequireProvide(dig.AsIn(someParam{}))
c.RequireProvide(dig.AsIn(reflect.TypeOf(anotherParam{})))

c.RequireInvoke(func(p someParam) {
require.True(t, called, "constructor must be called first")
require.True(t, calledReader, "constructor must be calledBuffer first")
require.True(t, calledReader, "constructor must be calledReader first")

require.NotNil(t, p.Buffer, "someParam.Reader must not be nil")
require.NotNil(t, p.Reader, "someParam.Reader must not be nil")

require.NotNil(t, p.Another.Reader, "anotherParam.Reader must not be nil")
require.True(t, p.Reader == p.Another.Reader, "readers fields must match")

require.True(t, p.Reader == reader, "buffer must match constructor's return value")

require.NotNil(t, p.Buffer, "someParam.Buffer must not be nil")
require.NotNil(t, p.Another.Buffer, "anotherParam.Buffer must not be nil")
require.NotNil(t, p.AnotherEmbedded.Buffer, "anotherParamEmbedded.Reader must not be nil")
require.True(t, p.Buffer == p.AnotherEmbedded.Buffer, "buffers fields must match")

require.True(t, p.Buffer == p.Another.Buffer, "buffers fields must match")
require.True(t, p.Buffer == buff, "buffer must match constructor's return value")
})
})
Expand Down Expand Up @@ -638,6 +662,13 @@ func TestEndToEndSuccess(t *testing.T) {
A1 A `name:"first"` // should come from ret1 through ret2
A2 A `name:"second"` // should come from ret2
}

type paramAsIn struct {
A1 A `name:"first"` // should come from ret1 through ret2
A2 A `name:"second"` // should come from ret2
}
c.RequireProvide(dig.AsIn(paramAsIn{}))

c.RequireProvide(func() Ret2 {
return Ret2{
Ret1: Ret1{
Expand All @@ -651,6 +682,11 @@ func TestEndToEndSuccess(t *testing.T) {
assert.Equal(t, 1, p.A1.idx)
assert.Equal(t, 2, p.A2.idx)
})

c.RequireInvoke(func(p paramAsIn) {
assert.Equal(t, 1, p.A1.idx)
assert.Equal(t, 2, p.A2.idx)
})
})

t.Run("named instances do not cause cycles", func(t *testing.T) {
Expand Down Expand Up @@ -709,7 +745,7 @@ func TestEndToEndSuccess(t *testing.T) {

require.Error(t, c.Invoke(func(*bytes.Buffer) {
t.Fatalf("must not be called")
}), "must not have a *bytes.Buffer in the container")
}), "must not have a *bytes.Reader in the container")
})

t.Run("As with Name", func(t *testing.T) {
Expand Down
83 changes: 83 additions & 0 deletions inout.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,89 @@ func embedsType(i interface{}, e reflect.Type) bool {
return false
}

// AsIn marks struct as In by creating reflect.StructOf.
func AsIn(i any) any {
t, ok := inType(i)
if !ok {
return nil
}

embeddingType := reflect.TypeOf(embeddingIn(t))
fnType := reflect.FuncOf([]reflect.Type{embeddingType}, []reflect.Type{t}, false)

fn := reflect.MakeFunc(fnType, func(args []reflect.Value) []reflect.Value {
in := args[0]
out := reflect.New(t).Elem()

outIndex := 0
for inIndex := 0; inIndex < in.NumField(); inIndex++ {
if in.Field(inIndex).Type() == _inType {
continue
}

out.Field(outIndex).Set(in.Field(inIndex))
outIndex++
}

return []reflect.Value{out}
})

return fn.Interface()
}

func embeddingIn(t reflect.Type) any {
return embedding(t, "In", _inType)
}

func embedding(i any, name string, _type reflect.Type) any {
t, ok := inType(i)
if !ok {
return nil
}

if t.Kind() == reflect.Ptr {
t = t.Elem()
}

if t.Kind() != reflect.Struct {
return nil
}

// Build fields: start with embedded In
fields := make([]reflect.StructField, 0, t.NumField()+1)
fields = append(fields, reflect.StructField{
Name: name,
Type: _type,
Anonymous: true,
})

// Add all original fields
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
fields = append(fields, reflect.StructField{
Name: f.Name,
Type: f.Type,
Tag: f.Tag,
})
}

newType := reflect.StructOf(fields)
return reflect.New(newType).Elem().Interface()
}

func inType(i any) (reflect.Type, bool) {
if i == nil {
return nil, false
}

t, ok := i.(reflect.Type)
if !ok {
t = reflect.TypeOf(i)
}

return t, true
}

// Checks if a field of an In struct is optional.
func isFieldOptional(f reflect.StructField) (bool, error) {
tag := f.Tag.Get(_optionalTag)
Expand Down