diff --git a/decoders/cue/cue.go b/decoders/cue/cue.go index 8761452..8f3e1a1 100644 --- a/decoders/cue/cue.go +++ b/decoders/cue/cue.go @@ -4,11 +4,13 @@ import ( "fmt" "io" "reflect" + "time" "cuelang.org/go/cue/cuecontext" "github.com/vimeo/dials" "github.com/vimeo/dials/common" + "github.com/vimeo/dials/decoders/json/jsontypes" "github.com/vimeo/dials/tagformat" "github.com/vimeo/dials/transform" ) @@ -16,6 +18,18 @@ import ( // Decoder is a decoder that knows how to work with configs written in Cue type Decoder struct{} +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} + +// pre-declare the time.Duration -> jsontypes.ParsingDuration mangler at +// package-scope, so we don't have to construct a new one every time Decode is +// called. +var parsingDurMangler = must(transform.NewSingleTypeSubstitutionMangler[time.Duration, jsontypes.ParsingDuration]()) + // Decode is a decoder that decodes the Cue config from an io.Reader into the // appropriate struct. func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) { @@ -27,7 +41,9 @@ func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) { const jsonTagName = "json" // If there aren't any json tags, copy over from any dials tags. + // Also, convert any time.Duration fields to jsontypes.ParsingDuration so we can decode those values as strings. tfmr := transform.NewTransformer(t.Type(), + parsingDurMangler, &tagformat.TagCopyingMangler{ SrcTag: common.DialsTagName, NewTag: jsonTagName}) reflVal, tfmErr := tfmr.Translate() @@ -43,5 +59,11 @@ func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) { if decErr := val.Decode(reflVal.Addr().Interface()); decErr != nil { return reflect.Value{}, fmt.Errorf("failed to decode cue value into dials struct: %w", decErr) } - return reflVal, nil + + unmangledVal, unmangleErr := tfmr.ReverseTranslate(reflVal) + if unmangleErr != nil { + return reflect.Value{}, unmangleErr + } + + return unmangledVal, nil } diff --git a/decoders/cue/cue_test.go b/decoders/cue/cue_test.go index a49e22f..d45cdab 100644 --- a/decoders/cue/cue_test.go +++ b/decoders/cue/cue_test.go @@ -4,6 +4,7 @@ import ( "context" "net" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -82,13 +83,17 @@ func TestDeeplyNestedCueJSON(t *testing.T) { Username string `dials:"username"` Password string `dials:"password"` OtherStuff struct { - Something string `dials:"something"` - IPAddress net.IP `dials:"ip_address"` + Something string `dials:"something"` + IPAddress net.IP `dials:"ip_address"` + SomeTimeout time.Duration `dials:"some_timeout"` + SomeOtherTimeout time.Duration `dials:"some_other_timeout"` + SomeLifetime time.Duration `dials:"some_lifetime_ns"` } `dials:"other_stuff"` } `dials:"database_user"` } - jsonData := `{ + cueData := ` + import "time" "database_name": "something", "database_address": "127.0.0.1", "database_user": { @@ -97,15 +102,18 @@ func TestDeeplyNestedCueJSON(t *testing.T) { "other_stuff": { "something": "asdf", "ip_address": "123.10.11.121" + "some_timeout": "13s" + "some_other_timeout": 87 * time.Second, + "some_lifetime_ns": 378, } } - }` + ` myConfig := &testConfig{} d, err := dials.Config( context.Background(), myConfig, - &static.StringSource{Data: jsonData, Decoder: &Decoder{}}, + &static.StringSource{Data: cueData, Decoder: &Decoder{}}, ) require.NoError(t, err) @@ -115,6 +123,9 @@ func TestDeeplyNestedCueJSON(t *testing.T) { assert.Equal(t, "test", c.DatabaseUser.Username) assert.Equal(t, "password", c.DatabaseUser.Password) assert.Equal(t, "asdf", c.DatabaseUser.OtherStuff.Something) + assert.Equal(t, time.Second*13, c.DatabaseUser.OtherStuff.SomeTimeout) + assert.Equal(t, time.Second*87, c.DatabaseUser.OtherStuff.SomeOtherTimeout) + assert.Equal(t, time.Nanosecond*378, c.DatabaseUser.OtherStuff.SomeLifetime) assert.Equal(t, net.IPv4(123, 10, 11, 121), c.DatabaseUser.OtherStuff.IPAddress) } diff --git a/decoders/json/json.go b/decoders/json/json.go index b8d93de..88360fb 100644 --- a/decoders/json/json.go +++ b/decoders/json/json.go @@ -4,11 +4,12 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "reflect" + "time" "github.com/vimeo/dials" "github.com/vimeo/dials/common" + "github.com/vimeo/dials/decoders/json/jsontypes" "github.com/vimeo/dials/tagformat" "github.com/vimeo/dials/transform" ) @@ -17,19 +18,31 @@ import ( const JSONTagName = "json" // Decoder is a decoder that knows how to work with text encoded in JSON -type Decoder struct { +type Decoder struct{} + +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v } +// pre-declare the time.Duration -> jsontypes.ParsingDuration mangler at +// package-scope, so we don't have to construct a new one every time Decode is +// called. +var parsingDurMangler = must(transform.NewSingleTypeSubstitutionMangler[time.Duration, jsontypes.ParsingDuration]()) + // Decode is a decoder that decodes the JSON from an io.Reader into the // appropriate struct. func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) { - jsonBytes, err := ioutil.ReadAll(r) + jsonBytes, err := io.ReadAll(r) if err != nil { return reflect.Value{}, fmt.Errorf("error reading JSON: %s", err) } // If there aren't any json tags, copy over from any dials tags. tfmr := transform.NewTransformer(t.Type(), + parsingDurMangler, &tagformat.TagCopyingMangler{ SrcTag: common.DialsTagName, NewTag: JSONTagName}) val, tfmErr := tfmr.Translate() diff --git a/decoders/json/json_test.go b/decoders/json/json_test.go index d40272e..460f41a 100644 --- a/decoders/json/json_test.go +++ b/decoders/json/json_test.go @@ -4,9 +4,11 @@ import ( "context" "net" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/vimeo/dials" "github.com/vimeo/dials/sources/static" ) @@ -81,8 +83,9 @@ func TestDeeplyNestedJSON(t *testing.T) { Username string `dials:"username"` Password string `dials:"password"` OtherStuff struct { - Something string `dials:"something"` - IPAddress net.IP `dials:"ip_address"` + Something string `dials:"something"` + IPAddress net.IP `dials:"ip_address"` + SomeTimeout time.Duration `dials:"some_timeout"` } `dials:"other_stuff"` } `dials:"database_user"` } @@ -95,7 +98,8 @@ func TestDeeplyNestedJSON(t *testing.T) { "password": "password", "other_stuff": { "something": "asdf", - "ip_address": "123.10.11.121" + "ip_address": "123.10.11.121", + "some_timeout": "13s" } } }` @@ -114,6 +118,7 @@ func TestDeeplyNestedJSON(t *testing.T) { assert.Equal(t, "test", c.DatabaseUser.Username) assert.Equal(t, "password", c.DatabaseUser.Password) assert.Equal(t, "asdf", c.DatabaseUser.OtherStuff.Something) + assert.Equal(t, time.Second*13, c.DatabaseUser.OtherStuff.SomeTimeout) assert.Equal(t, net.IPv4(123, 10, 11, 121), c.DatabaseUser.OtherStuff.IPAddress) } diff --git a/decoders/json/jsontypes/jsonduration.go b/decoders/json/jsontypes/jsonduration.go new file mode 100644 index 0000000..e198a90 --- /dev/null +++ b/decoders/json/jsontypes/jsonduration.go @@ -0,0 +1,42 @@ +// Package jsontypes contains helper types used by the JSON and Cue decoders to +// facilitate more natural decoding. +package jsontypes + +import ( + "bytes" + "encoding/json" + "fmt" + "time" +) + +// ParsingDuration implements [encoding/json.Unmarshaler], supporting both +// quoted strings that are parseable with [time.ParseDuration], and integer nanoseconds if it's a number +type ParsingDuration int64 + +// UnmarshalJSON implements [encoding/json.Unmarshaler] for ParsingDuration. +func (p *ParsingDuration) UnmarshalJSON(b []byte) error { + d := json.NewDecoder(bytes.NewReader(b)) + d.UseNumber() + n, tokErr := d.Token() + if tokErr != nil { + return fmt.Errorf("failed to parse token: %w", tokErr) + } + switch v := n.(type) { + case string: + dur, durParseErr := time.ParseDuration(v) + if durParseErr != nil { + return fmt.Errorf("failed to parse %q as duration: %w", v, durParseErr) + } + *p = ParsingDuration(dur) + return nil + case json.Number: + i, intParseErr := v.Int64() + if intParseErr != nil { + return fmt.Errorf("failed to parse number as integer nanoseconds: %w", intParseErr) + } + *p = ParsingDuration(i) + return nil + default: + return fmt.Errorf("unexpected JSON token-type %T; expected string or number", n) + } +} diff --git a/decoders/json/jsontypes/jsonduration_test.go b/decoders/json/jsontypes/jsonduration_test.go new file mode 100644 index 0000000..bdebf11 --- /dev/null +++ b/decoders/json/jsontypes/jsonduration_test.go @@ -0,0 +1,88 @@ +package jsontypes + +import ( + "encoding/json" + "reflect" + "testing" + "time" +) + +func ptrVal[V any](v V) *V { + return &v +} + +func TestParsingDuration(t *testing.T) { + t.Parallel() + type decStruct struct { + Dur ParsingDuration + DurPtr *ParsingDuration + } + for _, tbl := range []struct { + name string + inJSON string + expStruct any + expErr bool + }{ + { + name: "integer_value_no_ptr", + inJSON: `{"dur": 1234}`, + expStruct: decStruct{Dur: 1234}, + expErr: false, + }, + { + name: "string_value_no_ptr", + inJSON: `{"dur": "3s"}`, + expStruct: decStruct{Dur: ParsingDuration(3 * time.Second)}, + expErr: false, + }, + { + name: "string_value_ptr", + inJSON: `{"dur": "3s", "durptr": "9m"}`, + expStruct: decStruct{Dur: ParsingDuration(3 * time.Second), DurPtr: ptrVal(ParsingDuration(9 * time.Minute))}, + expErr: false, + }, + { + name: "integer_value_ptr", + inJSON: `{"dur": "3s", "durptr": 2048}`, + expStruct: decStruct{Dur: ParsingDuration(3 * time.Second), DurPtr: ptrVal(ParsingDuration(2048))}, + expErr: false, + }, + { + name: "error_array_val", + inJSON: `{"dur": [], "durptr": 2048}`, + expErr: true, + }, + { + name: "error_object_val", + inJSON: `{"dur": {}, "durptr": 2048}`, + expErr: true, + }, + { + name: "error_unparsable_str_val", + inJSON: `{"dur": "sssssssssss", "durptr": 2048}`, + expErr: true, + }, + { + name: "error_float_fractional_val", + inJSON: `{"dur": 0.333333, "durptr": 2048}`, + expErr: true, + }, + } { + t.Run(tbl.name, func(t *testing.T) { + v := decStruct{} + decErr := json.Unmarshal([]byte(tbl.inJSON), &v) + if decErr != nil { + if !tbl.expErr { + t.Errorf("unexpected error unmarshaling: %s", decErr) + } else { + t.Logf("expected error: %s", decErr) + } + return + } + if !reflect.DeepEqual(v, tbl.expStruct) { + t.Errorf("unexpected value\n got: %+v\nwant:%+v", tbl.expStruct, v) + } + }) + + } +}