Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

anonymous field flatten mangler #88

Merged
merged 4 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions decoders/yaml/yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ const YAMLTagName = "yaml"

// Decoder is a decoder that knows how to work with text encoded in YAML.
type Decoder struct {
// Flatten any anonymous struct fields into the parent
FlattenAnonymous bool
}

// Decode reads from `r` and decodes what is read as YAML depositing the
Expand All @@ -29,9 +31,13 @@ func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
return reflect.Value{}, fmt.Errorf("error reading YAML: %s", err)
}

manglers := []transform.Mangler{&tagformat.TagCopyingMangler{
SrcTag: common.DialsTagName, NewTag: YAMLTagName}}
if d.FlattenAnonymous {
manglers = append(manglers, transform.AnonymousFlattenMangler{})
}
tfmr := transform.NewTransformer(t.Type(),
&tagformat.TagCopyingMangler{
SrcTag: common.DialsTagName, NewTag: YAMLTagName},
manglers...,
)
val, tfmErr := tfmr.Translate()
if tfmErr != nil {
Expand Down
49 changes: 49 additions & 0 deletions decoders/yaml/yaml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/vimeo/dials"
"github.com/vimeo/dials/sources/static"
)
Expand Down Expand Up @@ -145,6 +146,54 @@ func TestMoreDeeplyNestedYAML(t *testing.T) {
assert.Equal(t, c.DatabaseUser.SliceThing[1].Zizzle, "fizzlebat")
}

func TestAnonymousNestedYAML(t *testing.T) {
type OtherStuff struct {
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
Timeout time.Duration `dials:"timeout"`
}
type testConfig struct {
DatabaseName string `dials:"database_name"`
DatabaseAddress string `dials:"database_address"`
DatabaseUser struct {
Username string `dials:"username"`
Password string `dials:"password"`
OtherStuff
} `dials:"database_user"`
}

yamlData := `{
"database_name": "something",
"database_address": "127.0.0.1",
"database_user": {
"username": "test",
"password": "password",
"something": "asdf",
"ip_address": "123.10.11.121",
"timeout": "10s",
}
}`

myConfig := &testConfig{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
d, err := dials.Config(
ctx,
myConfig,
&static.StringSource{Data: yamlData, Decoder: &Decoder{FlattenAnonymous: true}},
)

require.NoError(t, err)
c := d.View()

assert.Equal(t, "something", c.DatabaseName)
assert.Equal(t, "test", c.DatabaseUser.Username)
assert.Equal(t, "password", c.DatabaseUser.Password)
assert.Equal(t, "asdf", c.DatabaseUser.Something)
assert.Equal(t, net.IPv4(123, 10, 11, 121), c.DatabaseUser.IPAddress)
assert.Equal(t, time.Duration(10*time.Second), c.DatabaseUser.Timeout)
}

func TestDecoderBadMarkup(t *testing.T) {
type testConfig struct {
Val1 string
Expand Down
45 changes: 41 additions & 4 deletions ez/ez.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,23 @@ type Params[T any] struct {
// Note that this does not affect the flags or environment variable
// naming. To manipulate flag naming, see [Params.FlagConfig].
FileFieldNameEncoder caseconversion.EncodeCasingFunc

// FlattenAnonymousFields inserts the AnonymousFlattenMangler into the
// chain so decoders that do not handle anonymous fields never see such
// things.
// (Currently only affects the yaml decoder)
FlattenAnonymousFields bool
}

// DecoderFactory should return the appropriate decoder based on the config file
// path that is passed as the string argument to DecoderFactory
type DecoderFactory func(string) dials.Decoder

// DecoderFactoryWithParams should return the appropriate decoder based on the config file
// path that is passed as the string argument to DecoderFactory
// Params may provide useful context/arguments
type DecoderFactoryWithParams[T any] func(string, Params[T]) dials.Decoder

// ConfigWithConfigPath is an interface config struct that supplies a
// ConfigPath() method to indicate which file to read as the config file once
// populated.
Expand Down Expand Up @@ -125,6 +136,25 @@ func fileSource(cfgPath string, decoder dials.Decoder, watch bool) (dials.Source
// The contents of cfg for the defaults
// cfg.ConfigPath() is evaluated on the stacked config with the file-contents omitted (using a "blank" source)
func ConfigFileEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, cfg TP, df DecoderFactory, params Params[T]) (*dials.Dials[T], error) {
dfp := func(path string, _ Params[T]) dials.Decoder {
return df(path)
}
return ConfigFileEnvFlagDecoderFactoryParams(ctx, cfg, dfp, params)

}

// ConfigFileEnvFlagDecoderFactoryParams takes advantage of the ConfigWithConfigPath cfg to indicate
// what file to read and uses the passed decoder.
// Configuration values provided by the returned Dials are the result of
// stacking the sources in the following order:
// - configuration file
// - environment variables
// - flags it registers with the standard library flags package
//
// The contents of cfg for the defaults
// cfg.ConfigPath() is evaluated on the stacked config with the file-contents omitted (using a "blank" source)
// It differs from ConfigFileEnvFlag by the signature of the decoder factory, (which requires a params struct in this function)
func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ctx context.Context, cfg TP, df DecoderFactoryWithParams[T], params Params[T]) (*dials.Dials[T], error) {
blank := sourcewrap.Blank{}

flagSrc := params.FlagSource
Expand Down Expand Up @@ -184,7 +214,7 @@ func ConfigFileEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, c
return d, nil
}

decoder := df(cfgPath)
decoder := df(cfgPath, params)
if decoder == nil {
return nil, fmt.Errorf("decoderFactory provided a nil decoder for path: %s", cfgPath)
}
Expand Down Expand Up @@ -243,7 +273,7 @@ func ConfigFileEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, c
// YAMLConfigEnvFlag takes advantage of the ConfigWithConfigPath cfg, thinly
// wraping ConfigFileEnvFlag with the decoder statically set to YAML.
func YAMLConfigEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, cfg TP, params Params[T]) (*dials.Dials[T], error) {
return ConfigFileEnvFlag(ctx, cfg, func(string) dials.Decoder { return &yaml.Decoder{} }, params)
return ConfigFileEnvFlag(ctx, cfg, func(string) dials.Decoder { return &yaml.Decoder{FlattenAnonymous: params.FlattenAnonymousFields} }, params)
}

// JSONConfigEnvFlag takes advantage of the ConfigWithConfigPath cfg, thinly
Expand All @@ -268,10 +298,17 @@ func TOMLConfigEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, c
// based on the extension of the filename or nil if there is not an appropriate
// mapping.
func DecoderFromExtension(path string) dials.Decoder {
return DecoderFromExtensionWithParams(path, Params[struct{}]{})
}

// DecoderFromExtension is a DecoderFactory that returns an appropriate decoder
// based on the extension of the filename or nil if there is not an appropriate
// mapping.
func DecoderFromExtensionWithParams[T any](path string, p Params[T]) dials.Decoder {
ext := filepath.Ext(path)
switch strings.ToLower(ext) {
case ".yaml", ".yml":
return &yaml.Decoder{}
return &yaml.Decoder{FlattenAnonymous: p.FlattenAnonymousFields}
case ".json":
return &json.Decoder{}
case ".toml":
Expand All @@ -289,5 +326,5 @@ func DecoderFromExtension(path string) dials.Decoder {
// file contents based on the file extension (from the limited set of JSON,
// Cue, YAML and TOML).
func FileExtensionDecoderConfigEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, cfg TP, params Params[T]) (*dials.Dials[T], error) {
return ConfigFileEnvFlag(ctx, cfg, DecoderFromExtension, params)
return ConfigFileEnvFlagDecoderFactoryParams(ctx, cfg, DecoderFromExtensionWithParams[T], params)
}
117 changes: 117 additions & 0 deletions transform/anonymous_flatten_mangler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package transform

import (
"reflect"
)

// AnonymousFlattenMangler hoists the fields from the types of anonymous
// struct-fields into the parent type. (working around decoders/sources that
// are unaware of anonymous fields)
// Note: this mangler is unaware of TextUnmarshaler implementations (it's tricky to do right when flattening).
// It should be combined with the TextUnmarshalerMangler if the prefered
// handling is to mask the other fields in that struct with the TextUnmarshaler
// implementation.
type AnonymousFlattenMangler struct{}

// Mangle is called for every field in a struct, and maps that to one or more output fields.
// Implementations that desire to leave fields unchanged should return
// the argument unchanged. (particularly useful if taking advantage of
// recursive evaluation)
func (a AnonymousFlattenMangler) Mangle(sf reflect.StructField) ([]reflect.StructField, error) {
// If it's not an anonymous field, return it as-is
if !sf.Anonymous {
return []reflect.StructField{sf}, nil
}
// Note: TranslateType already skips unexported fields

// anonymous/embedded fields can only be interfaces, pointers and structs
switch sf.Type.Kind() {
case reflect.Pointer:
// recurse with the pointer stripped off
sfInner := sf
sfInner.Type = sf.Type.Elem()
return a.Mangle(sfInner)
case reflect.Struct:
out := make([]reflect.StructField, 0, sf.Type.NumField())
for i := 0; i < sf.Type.NumField(); i++ {
innerField := sf.Type.Field(i)
if !innerField.IsExported() {
// skip unexported fields
continue
}
out = append(out, innerField)
}

return out, nil
default:
// leave everything else alone (there's nothing to promote)
// this includes interfaces and all other non-struct and
// non-pointer-to-struct types.
return []reflect.StructField{sf}, nil
}
}

// bool return value indicates whether all fields are nil (and as such, a nil value should be returned for pointer-types)
func (a AnonymousFlattenMangler) unmangleStruct(sf reflect.StructField, fvs []FieldValueTuple) (reflect.Value, bool) {
out := reflect.New(sf.Type).Elem()
if len(fvs) == 0 {
// no fields made it, just return out.
return out, true
}
fvsIdx := 0
allNil := true
for i := 0; i < sf.Type.NumField(); i++ {
oft := sf.Type.Field(i)
if oft.Name == fvs[fvsIdx].Field.Name {
out.Field(i).Set(fvs[fvsIdx].Value)
switch fvs[fvsIdx].Value.Kind() {
// check for nil-able types
case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Interface, reflect.Chan:
if !fvs[fvsIdx].Value.IsZero() {
allNil = false
}
default:
// non-nilable field, just assume it's non-nil
// pointerification shold have made this nilable, though.
allNil = false
}
fvsIdx++
}
}
return out, allNil
}

// Unmangle is called for every source-field->mangled-field
// mapping-set, with the mangled-field and its populated value set. The
// implementation of Unmangle should return a reflect.Value that will
// be used for the next mangler or final struct value)
// Returned reflect.Value should be convertible to the field's type.
func (a AnonymousFlattenMangler) Unmangle(sf reflect.StructField, fvs []FieldValueTuple) (reflect.Value, error) {
if !sf.Anonymous {
// not anonymous, just forward the single value
return fvs[0].Value, nil
}
switch sf.Type.Kind() {
case reflect.Pointer:
// It's a pointer. check for nil; strip off the pointer and recurse
msf := sf
msf.Type = sf.Type.Elem()
v, allNil := a.unmangleStruct(msf, fvs)
if allNil {
return reflect.Zero(sf.Type), nil
}
return v.Addr(), nil
case reflect.Struct:
out, _ := a.unmangleStruct(sf, fvs)
return out, nil
default:
// not a struct-typed anonymous field, just forward up the chain
return fvs[0].Value, nil
}
}

// ShouldRecurse is called after Mangle for each field so nested struct
// fields get iterated over after any transformation done by Mangle().
func (a AnonymousFlattenMangler) ShouldRecurse(_ reflect.StructField) bool {
return true
}
Loading
Loading