Skip to content

Commit

Permalink
Merge pull request #92 from vimeo/parsingduration_json_cue
Browse files Browse the repository at this point in the history
json/cue: substitute time.Duration with type using time.ParseDuration
  • Loading branch information
dfinkel authored May 20, 2024
2 parents e43afea + fe58cd1 commit 875c7b5
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 12 deletions.
24 changes: 23 additions & 1 deletion decoders/cue/cue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,32 @@ import (
"fmt"
"io"
"reflect"
"time"

"cuelang.org/go/cue/cuecontext"

"github.com/vimeo/dials"
"github.com/vimeo/dials/common"
"github.com/vimeo/dials/decoders/json/jsontypes"
"github.com/vimeo/dials/tagformat"
"github.com/vimeo/dials/transform"
)

// Decoder is a decoder that knows how to work with configs written in Cue
type Decoder struct{}

func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}

// pre-declare the time.Duration -> jsontypes.ParsingDuration mangler at
// package-scope, so we don't have to construct a new one every time Decode is
// called.
var parsingDurMangler = must(transform.NewSingleTypeSubstitutionMangler[time.Duration, jsontypes.ParsingDuration]())

// Decode is a decoder that decodes the Cue config from an io.Reader into the
// appropriate struct.
func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
Expand All @@ -27,7 +41,9 @@ func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
const jsonTagName = "json"

// If there aren't any json tags, copy over from any dials tags.
// Also, convert any time.Duration fields to jsontypes.ParsingDuration so we can decode those values as strings.
tfmr := transform.NewTransformer(t.Type(),
parsingDurMangler,
&tagformat.TagCopyingMangler{
SrcTag: common.DialsTagName, NewTag: jsonTagName})
reflVal, tfmErr := tfmr.Translate()
Expand All @@ -43,5 +59,11 @@ func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
if decErr := val.Decode(reflVal.Addr().Interface()); decErr != nil {
return reflect.Value{}, fmt.Errorf("failed to decode cue value into dials struct: %w", decErr)
}
return reflVal, nil

unmangledVal, unmangleErr := tfmr.ReverseTranslate(reflVal)
if unmangleErr != nil {
return reflect.Value{}, unmangleErr
}

return unmangledVal, nil
}
21 changes: 16 additions & 5 deletions decoders/cue/cue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -82,13 +83,17 @@ func TestDeeplyNestedCueJSON(t *testing.T) {
Username string `dials:"username"`
Password string `dials:"password"`
OtherStuff struct {
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
SomeTimeout time.Duration `dials:"some_timeout"`
SomeOtherTimeout time.Duration `dials:"some_other_timeout"`
SomeLifetime time.Duration `dials:"some_lifetime_ns"`
} `dials:"other_stuff"`
} `dials:"database_user"`
}

jsonData := `{
cueData := `
import "time"
"database_name": "something",
"database_address": "127.0.0.1",
"database_user": {
Expand All @@ -97,15 +102,18 @@ func TestDeeplyNestedCueJSON(t *testing.T) {
"other_stuff": {
"something": "asdf",
"ip_address": "123.10.11.121"
"some_timeout": "13s"
"some_other_timeout": 87 * time.Second,
"some_lifetime_ns": 378,
}
}
}`
`

myConfig := &testConfig{}
d, err := dials.Config(
context.Background(),
myConfig,
&static.StringSource{Data: jsonData, Decoder: &Decoder{}},
&static.StringSource{Data: cueData, Decoder: &Decoder{}},
)
require.NoError(t, err)

Expand All @@ -115,6 +123,9 @@ func TestDeeplyNestedCueJSON(t *testing.T) {
assert.Equal(t, "test", c.DatabaseUser.Username)
assert.Equal(t, "password", c.DatabaseUser.Password)
assert.Equal(t, "asdf", c.DatabaseUser.OtherStuff.Something)
assert.Equal(t, time.Second*13, c.DatabaseUser.OtherStuff.SomeTimeout)
assert.Equal(t, time.Second*87, c.DatabaseUser.OtherStuff.SomeOtherTimeout)
assert.Equal(t, time.Nanosecond*378, c.DatabaseUser.OtherStuff.SomeLifetime)
assert.Equal(t, net.IPv4(123, 10, 11, 121), c.DatabaseUser.OtherStuff.IPAddress)

}
Expand Down
19 changes: 16 additions & 3 deletions decoders/json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"reflect"
"time"

"github.com/vimeo/dials"
"github.com/vimeo/dials/common"
"github.com/vimeo/dials/decoders/json/jsontypes"
"github.com/vimeo/dials/tagformat"
"github.com/vimeo/dials/transform"
)
Expand All @@ -17,19 +18,31 @@ import (
const JSONTagName = "json"

// Decoder is a decoder that knows how to work with text encoded in JSON
type Decoder struct {
type Decoder struct{}

func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}

// pre-declare the time.Duration -> jsontypes.ParsingDuration mangler at
// package-scope, so we don't have to construct a new one every time Decode is
// called.
var parsingDurMangler = must(transform.NewSingleTypeSubstitutionMangler[time.Duration, jsontypes.ParsingDuration]())

// Decode is a decoder that decodes the JSON from an io.Reader into the
// appropriate struct.
func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
jsonBytes, err := ioutil.ReadAll(r)
jsonBytes, err := io.ReadAll(r)
if err != nil {
return reflect.Value{}, fmt.Errorf("error reading JSON: %s", err)
}

// If there aren't any json tags, copy over from any dials tags.
tfmr := transform.NewTransformer(t.Type(),
parsingDurMangler,
&tagformat.TagCopyingMangler{
SrcTag: common.DialsTagName, NewTag: JSONTagName})
val, tfmErr := tfmr.Translate()
Expand Down
11 changes: 8 additions & 3 deletions decoders/json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/vimeo/dials"
"github.com/vimeo/dials/sources/static"
)
Expand Down Expand Up @@ -81,8 +83,9 @@ func TestDeeplyNestedJSON(t *testing.T) {
Username string `dials:"username"`
Password string `dials:"password"`
OtherStuff struct {
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
SomeTimeout time.Duration `dials:"some_timeout"`
} `dials:"other_stuff"`
} `dials:"database_user"`
}
Expand All @@ -95,7 +98,8 @@ func TestDeeplyNestedJSON(t *testing.T) {
"password": "password",
"other_stuff": {
"something": "asdf",
"ip_address": "123.10.11.121"
"ip_address": "123.10.11.121",
"some_timeout": "13s"
}
}
}`
Expand All @@ -114,6 +118,7 @@ func TestDeeplyNestedJSON(t *testing.T) {
assert.Equal(t, "test", c.DatabaseUser.Username)
assert.Equal(t, "password", c.DatabaseUser.Password)
assert.Equal(t, "asdf", c.DatabaseUser.OtherStuff.Something)
assert.Equal(t, time.Second*13, c.DatabaseUser.OtherStuff.SomeTimeout)
assert.Equal(t, net.IPv4(123, 10, 11, 121), c.DatabaseUser.OtherStuff.IPAddress)

}
Expand Down
42 changes: 42 additions & 0 deletions decoders/json/jsontypes/jsonduration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Package jsontypes contains helper types used by the JSON and Cue decoders to
// facilitate more natural decoding.
package jsontypes

import (
"bytes"
"encoding/json"
"fmt"
"time"
)

// ParsingDuration implements [encoding/json.Unmarshaler], supporting both
// quoted strings that are parseable with [time.ParseDuration], and integer nanoseconds if it's a number
type ParsingDuration int64

// UnmarshalJSON implements [encoding/json.Unmarshaler] for ParsingDuration.
func (p *ParsingDuration) UnmarshalJSON(b []byte) error {
d := json.NewDecoder(bytes.NewReader(b))
d.UseNumber()
n, tokErr := d.Token()
if tokErr != nil {
return fmt.Errorf("failed to parse token: %w", tokErr)
}
switch v := n.(type) {
case string:
dur, durParseErr := time.ParseDuration(v)
if durParseErr != nil {
return fmt.Errorf("failed to parse %q as duration: %w", v, durParseErr)
}
*p = ParsingDuration(dur)
return nil
case json.Number:
i, intParseErr := v.Int64()
if intParseErr != nil {
return fmt.Errorf("failed to parse number as integer nanoseconds: %w", intParseErr)
}
*p = ParsingDuration(i)
return nil
default:
return fmt.Errorf("unexpected JSON token-type %T; expected string or number", n)
}
}
88 changes: 88 additions & 0 deletions decoders/json/jsontypes/jsonduration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package jsontypes

import (
"encoding/json"
"reflect"
"testing"
"time"
)

func ptrVal[V any](v V) *V {
return &v
}

func TestParsingDuration(t *testing.T) {
t.Parallel()
type decStruct struct {
Dur ParsingDuration
DurPtr *ParsingDuration
}
for _, tbl := range []struct {
name string
inJSON string
expStruct any
expErr bool
}{
{
name: "integer_value_no_ptr",
inJSON: `{"dur": 1234}`,
expStruct: decStruct{Dur: 1234},
expErr: false,
},
{
name: "string_value_no_ptr",
inJSON: `{"dur": "3s"}`,
expStruct: decStruct{Dur: ParsingDuration(3 * time.Second)},
expErr: false,
},
{
name: "string_value_ptr",
inJSON: `{"dur": "3s", "durptr": "9m"}`,
expStruct: decStruct{Dur: ParsingDuration(3 * time.Second), DurPtr: ptrVal(ParsingDuration(9 * time.Minute))},
expErr: false,
},
{
name: "integer_value_ptr",
inJSON: `{"dur": "3s", "durptr": 2048}`,
expStruct: decStruct{Dur: ParsingDuration(3 * time.Second), DurPtr: ptrVal(ParsingDuration(2048))},
expErr: false,
},
{
name: "error_array_val",
inJSON: `{"dur": [], "durptr": 2048}`,
expErr: true,
},
{
name: "error_object_val",
inJSON: `{"dur": {}, "durptr": 2048}`,
expErr: true,
},
{
name: "error_unparsable_str_val",
inJSON: `{"dur": "sssssssssss", "durptr": 2048}`,
expErr: true,
},
{
name: "error_float_fractional_val",
inJSON: `{"dur": 0.333333, "durptr": 2048}`,
expErr: true,
},
} {
t.Run(tbl.name, func(t *testing.T) {
v := decStruct{}
decErr := json.Unmarshal([]byte(tbl.inJSON), &v)
if decErr != nil {
if !tbl.expErr {
t.Errorf("unexpected error unmarshaling: %s", decErr)
} else {
t.Logf("expected error: %s", decErr)
}
return
}
if !reflect.DeepEqual(v, tbl.expStruct) {
t.Errorf("unexpected value\n got: %+v\nwant:%+v", tbl.expStruct, v)
}
})

}
}

0 comments on commit 875c7b5

Please sign in to comment.