Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

json/cue: substitute time.Duration with type using time.ParseDuration #92

Merged
merged 3 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion decoders/cue/cue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,32 @@ import (
"fmt"
"io"
"reflect"
"time"

"cuelang.org/go/cue/cuecontext"

"github.com/vimeo/dials"
"github.com/vimeo/dials/common"
"github.com/vimeo/dials/decoders/json/jsontypes"
"github.com/vimeo/dials/tagformat"
"github.com/vimeo/dials/transform"
)

// Decoder is a decoder that knows how to work with configs written in Cue
type Decoder struct{}

func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}

// pre-declare the time.Duration -> jsontypes.ParsingDuration mangler at
// package-scope, so we don't have to construct a new one every time Decode is
// called.
var parsingDurMangler = must(transform.NewSingleTypeSubstitutionMangler[time.Duration, jsontypes.ParsingDuration]())

// Decode is a decoder that decodes the Cue config from an io.Reader into the
// appropriate struct.
func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
Expand All @@ -27,7 +41,9 @@ func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
const jsonTagName = "json"

// If there aren't any json tags, copy over from any dials tags.
// Also, convert any time.Duration fields to jsontypes.ParsingDuration so we can decode those values as strings.
tfmr := transform.NewTransformer(t.Type(),
parsingDurMangler,
&tagformat.TagCopyingMangler{
SrcTag: common.DialsTagName, NewTag: jsonTagName})
reflVal, tfmErr := tfmr.Translate()
Expand All @@ -43,5 +59,11 @@ func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
if decErr := val.Decode(reflVal.Addr().Interface()); decErr != nil {
return reflect.Value{}, fmt.Errorf("failed to decode cue value into dials struct: %w", decErr)
}
return reflVal, nil

unmangledVal, unmangleErr := tfmr.ReverseTranslate(reflVal)
if unmangleErr != nil {
return reflect.Value{}, unmangleErr
}

return unmangledVal, nil
}
21 changes: 16 additions & 5 deletions decoders/cue/cue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -82,13 +83,17 @@ func TestDeeplyNestedCueJSON(t *testing.T) {
Username string `dials:"username"`
Password string `dials:"password"`
OtherStuff struct {
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
SomeTimeout time.Duration `dials:"some_timeout"`
SomeOtherTimeout time.Duration `dials:"some_other_timeout"`
SomeLifetime time.Duration `dials:"some_lifetime_ns"`
} `dials:"other_stuff"`
} `dials:"database_user"`
}

jsonData := `{
cueData := `
import "time"
"database_name": "something",
"database_address": "127.0.0.1",
"database_user": {
Expand All @@ -97,15 +102,18 @@ func TestDeeplyNestedCueJSON(t *testing.T) {
"other_stuff": {
"something": "asdf",
"ip_address": "123.10.11.121"
"some_timeout": "13s"
"some_other_timeout": 87 * time.Second,
"some_lifetime_ns": 378,
}
}
}`
`

myConfig := &testConfig{}
d, err := dials.Config(
context.Background(),
myConfig,
&static.StringSource{Data: jsonData, Decoder: &Decoder{}},
&static.StringSource{Data: cueData, Decoder: &Decoder{}},
)
require.NoError(t, err)

Expand All @@ -115,6 +123,9 @@ func TestDeeplyNestedCueJSON(t *testing.T) {
assert.Equal(t, "test", c.DatabaseUser.Username)
assert.Equal(t, "password", c.DatabaseUser.Password)
assert.Equal(t, "asdf", c.DatabaseUser.OtherStuff.Something)
assert.Equal(t, time.Second*13, c.DatabaseUser.OtherStuff.SomeTimeout)
assert.Equal(t, time.Second*87, c.DatabaseUser.OtherStuff.SomeOtherTimeout)
assert.Equal(t, time.Nanosecond*378, c.DatabaseUser.OtherStuff.SomeLifetime)
assert.Equal(t, net.IPv4(123, 10, 11, 121), c.DatabaseUser.OtherStuff.IPAddress)

}
Expand Down
19 changes: 16 additions & 3 deletions decoders/json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"reflect"
"time"

"github.com/vimeo/dials"
"github.com/vimeo/dials/common"
"github.com/vimeo/dials/decoders/json/jsontypes"
"github.com/vimeo/dials/tagformat"
"github.com/vimeo/dials/transform"
)
Expand All @@ -17,19 +18,31 @@ import (
const JSONTagName = "json"

// Decoder is a decoder that knows how to work with text encoded in JSON
type Decoder struct {
type Decoder struct{}

func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}

// pre-declare the time.Duration -> jsontypes.ParsingDuration mangler at
// package-scope, so we don't have to construct a new one every time Decode is
// called.
var parsingDurMangler = must(transform.NewSingleTypeSubstitutionMangler[time.Duration, jsontypes.ParsingDuration]())

// Decode is a decoder that decodes the JSON from an io.Reader into the
// appropriate struct.
func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
jsonBytes, err := ioutil.ReadAll(r)
jsonBytes, err := io.ReadAll(r)
if err != nil {
return reflect.Value{}, fmt.Errorf("error reading JSON: %s", err)
}

// If there aren't any json tags, copy over from any dials tags.
tfmr := transform.NewTransformer(t.Type(),
parsingDurMangler,
&tagformat.TagCopyingMangler{
SrcTag: common.DialsTagName, NewTag: JSONTagName})
val, tfmErr := tfmr.Translate()
Expand Down
11 changes: 8 additions & 3 deletions decoders/json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/vimeo/dials"
"github.com/vimeo/dials/sources/static"
)
Expand Down Expand Up @@ -81,8 +83,9 @@ func TestDeeplyNestedJSON(t *testing.T) {
Username string `dials:"username"`
Password string `dials:"password"`
OtherStuff struct {
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
SomeTimeout time.Duration `dials:"some_timeout"`
} `dials:"other_stuff"`
} `dials:"database_user"`
}
Expand All @@ -95,7 +98,8 @@ func TestDeeplyNestedJSON(t *testing.T) {
"password": "password",
"other_stuff": {
"something": "asdf",
"ip_address": "123.10.11.121"
"ip_address": "123.10.11.121",
"some_timeout": "13s"
}
}
}`
Expand All @@ -114,6 +118,7 @@ func TestDeeplyNestedJSON(t *testing.T) {
assert.Equal(t, "test", c.DatabaseUser.Username)
assert.Equal(t, "password", c.DatabaseUser.Password)
assert.Equal(t, "asdf", c.DatabaseUser.OtherStuff.Something)
assert.Equal(t, time.Second*13, c.DatabaseUser.OtherStuff.SomeTimeout)
assert.Equal(t, net.IPv4(123, 10, 11, 121), c.DatabaseUser.OtherStuff.IPAddress)

}
Expand Down
42 changes: 42 additions & 0 deletions decoders/json/jsontypes/jsonduration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Package jsontypes contains helper types used by the JSON and Cue decoders to
// facilitate more natural decoding.
package jsontypes

import (
"bytes"
"encoding/json"
"fmt"
"time"
)

// ParsingDuration implements [encoding/json.Unmarshaler], supporting both
// quoted strings that are parseable with [time.ParseDuration], and integer nanoseconds if it's a number
type ParsingDuration int64

// UnmarshalJSON implements [encoding/json.Unmarshaler] for ParsingDuration.
func (p *ParsingDuration) UnmarshalJSON(b []byte) error {
d := json.NewDecoder(bytes.NewReader(b))
d.UseNumber()
n, tokErr := d.Token()
if tokErr != nil {
return fmt.Errorf("failed to parse token: %w", tokErr)
}
switch v := n.(type) {
case string:
dur, durParseErr := time.ParseDuration(v)
if durParseErr != nil {
return fmt.Errorf("failed to parse %q as duration: %w", v, durParseErr)
}
*p = ParsingDuration(dur)
return nil
case json.Number:
i, intParseErr := v.Int64()
if intParseErr != nil {
return fmt.Errorf("failed to parse number as integer nanoseconds: %w", intParseErr)
}
*p = ParsingDuration(i)
return nil
default:
return fmt.Errorf("unexpected JSON token-type %T; expected string or number", n)
}
}
88 changes: 88 additions & 0 deletions decoders/json/jsontypes/jsonduration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package jsontypes

import (
"encoding/json"
"reflect"
"testing"
"time"
)

func ptrVal[V any](v V) *V {
return &v
}

func TestParsingDuration(t *testing.T) {
t.Parallel()
type decStruct struct {
Dur ParsingDuration
DurPtr *ParsingDuration
}
for _, tbl := range []struct {
name string
inJSON string
expStruct any
expErr bool
}{
{
name: "integer_value_no_ptr",
inJSON: `{"dur": 1234}`,
expStruct: decStruct{Dur: 1234},
expErr: false,
},
{
name: "string_value_no_ptr",
inJSON: `{"dur": "3s"}`,
expStruct: decStruct{Dur: ParsingDuration(3 * time.Second)},
expErr: false,
},
{
name: "string_value_ptr",
inJSON: `{"dur": "3s", "durptr": "9m"}`,
expStruct: decStruct{Dur: ParsingDuration(3 * time.Second), DurPtr: ptrVal(ParsingDuration(9 * time.Minute))},
expErr: false,
},
{
name: "integer_value_ptr",
inJSON: `{"dur": "3s", "durptr": 2048}`,
expStruct: decStruct{Dur: ParsingDuration(3 * time.Second), DurPtr: ptrVal(ParsingDuration(2048))},
expErr: false,
},
{
name: "error_array_val",
inJSON: `{"dur": [], "durptr": 2048}`,
expErr: true,
},
{
name: "error_object_val",
inJSON: `{"dur": {}, "durptr": 2048}`,
expErr: true,
},
{
name: "error_unparsable_str_val",
inJSON: `{"dur": "sssssssssss", "durptr": 2048}`,
expErr: true,
},
{
name: "error_float_fractional_val",
inJSON: `{"dur": 0.333333, "durptr": 2048}`,
expErr: true,
},
} {
t.Run(tbl.name, func(t *testing.T) {
v := decStruct{}
decErr := json.Unmarshal([]byte(tbl.inJSON), &v)
if decErr != nil {
if !tbl.expErr {
t.Errorf("unexpected error unmarshaling: %s", decErr)
} else {
t.Logf("expected error: %s", decErr)
}
return
}
if !reflect.DeepEqual(v, tbl.expStruct) {
t.Errorf("unexpected value\n got: %+v\nwant:%+v", tbl.expStruct, v)
}
})

}
}
Loading