Skip to content

Commit

Permalink
wire: add struct length prefix struct tag
Browse files Browse the repository at this point in the history
Add a marshal/unmarshal struct tag that indicates whether to handle a
length prefix for structs payloads.
  • Loading branch information
mk6i committed Jul 8, 2024
1 parent 5241751 commit c38ae1a
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 85 deletions.
92 changes: 43 additions & 49 deletions wire/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ func Unmarshal(v any, r io.Reader) error {
func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Reader) error {
switch v.Kind() {
case reflect.Struct:
if lenTag, ok := tag.Lookup("len_prefix"); ok {
bufLen, err := readUnsignedInt(lenTag, r)
if err != nil {
return err
}
b := make([]byte, bufLen)
if bufLen > 0 {
if _, err := io.ReadFull(r, b); err != nil {
return err
}
}
r = bytes.NewBuffer(b)
}
for i := 0; i < v.NumField(); i++ {
if err := unmarshal(t.Field(i).Type, v.Field(i), t.Field(i).Tag, r); err != nil {
return err
Expand All @@ -27,21 +40,9 @@ func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Read
case reflect.String:
var bufLen int
if lenTag, ok := tag.Lookup("len_prefix"); ok {
switch lenTag {
case "uint8":
var l uint8
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return err
}
bufLen = int(l)
case "uint16":
var l uint16
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return err
}
bufLen = int(l)
default:
return fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrUnmarshalFailure, lenTag)
var err error
if bufLen, err = readUnsignedInt(lenTag, r); err != nil {
return err
}
} else {
return fmt.Errorf("%w: missing len_prefix tag", ErrUnmarshalFailure)
Expand Down Expand Up @@ -85,24 +86,10 @@ func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Read
return nil
case reflect.Slice:
if lenTag, ok := tag.Lookup("len_prefix"); ok {
var bufLen int
switch lenTag {
case "uint8":
var l uint8
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return err
}
bufLen = int(l)
case "uint16":
var l uint16
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return err
}
bufLen = int(l)
default:
return fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrUnmarshalFailure, lenTag)
bufLen, err := readUnsignedInt(lenTag, r)
if err != nil {
return err
}

buf := make([]byte, bufLen)
if bufLen > 0 {
if _, err := io.ReadFull(r, buf); err != nil {
Expand All @@ -123,24 +110,10 @@ func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Read
}
v.Set(slice)
} else if countTag, ok := tag.Lookup("count_prefix"); ok {
var count int
switch countTag {
case "uint8":
var l uint8
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return err
}
count = int(l)
case "uint16":
var l uint16
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return err
}
count = int(l)
default:
return fmt.Errorf("%w: unsupported count_prefix type %s. allowed types: uint8, uint16", ErrUnmarshalFailure, lenTag)
count, err := readUnsignedInt(countTag, r)
if err != nil {
return err
}

slice := reflect.New(v.Type()).Elem()
for i := 0; i < count; i++ {
v1 := reflect.New(v.Type().Elem()).Interface()
Expand Down Expand Up @@ -169,3 +142,24 @@ func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Read
return fmt.Errorf("%w: unsupported type %v", ErrUnmarshalFailure, t.Kind())
}
}

func readUnsignedInt(intType string, r io.Reader) (int, error) {
var bufLen int
switch intType {
case "uint8":
var l uint8
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return 0, err
}
bufLen = int(l)
case "uint16":
var l uint16
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return 0, err
}
bufLen = int(l)
default:
return 0, fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrUnmarshalFailure, intType)
}
return bufLen, nil
}
93 changes: 93 additions & 0 deletions wire/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,99 @@ func TestUnmarshal(t *testing.T) {
[]byte{0x0, 0x02}, /* count prefix */
[]byte{0x0, 0xa, 0x0, 0x2, 0x4, 0xd2, 0x0, 0x14, 0x0, 0x2, 0x4, 0xd2}...), /* slice val */
},
{
name: "struct with uint8 len_prefix",
prototype: &struct {
Val0 uint8
Val1 struct {
Val2 uint16
Val3 uint8
} `len_prefix:"uint8"`
Val4 uint16
}{},
want: &struct {
Val0 uint8
Val1 struct {
Val2 uint16
Val3 uint8
} `len_prefix:"uint8"`
Val4 uint16
}{
Val0: 34,
Val1: struct {
Val2 uint16
Val3 uint8
}{
Val2: 16,
Val3: 10,
},
Val4: 32,
},
given: []byte{
0x22, // Val0
0x03, // Val1 struct len
0x00, 0x10, // Val2
0x0A, // Val3
0x00, 0x20, // Val2
},
},
{
name: "struct with uint16 len_prefix",
prototype: &struct {
Val0 uint8
Val1 struct {
Val2 uint16
Val3 uint8
} `len_prefix:"uint16"`
Val4 uint16
}{},
want: &struct {
Val0 uint8
Val1 struct {
Val2 uint16
Val3 uint8
} `len_prefix:"uint16"`
Val4 uint16
}{
Val0: 34,
Val1: struct {
Val2 uint16
Val3 uint8
}{
Val2: 16,
Val3: 10,
},
Val4: 32,
},
given: []byte{
0x22, // Val0
0x00, 0x03, // Val1 struct len
0x00, 0x10, // Val2
0x0A, // Val3
0x00, 0x20, // Val2
},
},
{
name: "struct with uint16 len_prefix with read error",
prototype: &struct {
Val1 struct {
Val2 uint16
} `len_prefix:"uint16"`
}{},
given: []byte{
0x00, 0x10, // 16 byte len, but the body is truncated
},
wantErr: io.EOF,
},
{
name: "struct with unknown len_prefix",
prototype: &struct {
Val1 struct {
Val2 uint16
} `len_prefix:"uint128"`
}{},
wantErr: ErrUnmarshalFailure,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
80 changes: 44 additions & 36 deletions wire/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,35 @@ func marshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, w io.Writer
}
switch t.Kind() {
case reflect.Struct:
for i := 0; i < t.NumField(); i++ {
if err := marshal(t.Field(i).Type, v.Field(i), t.Field(i).Tag, w); err != nil {
marshalEachField := func(w io.Writer) error {
for i := 0; i < t.NumField(); i++ {
if err := marshal(t.Field(i).Type, v.Field(i), t.Field(i).Tag, w); err != nil {
return err
}
}
return nil
}
if lenTag, ok := tag.Lookup("len_prefix"); ok {
buf := &bytes.Buffer{}
if err := marshalEachField(buf); err != nil {
return err
}
// write struct length
if err := writeUnsignedInt(lenTag, buf.Len(), w); err != nil {
return err
}
// write struct bytes
if buf.Len() > 0 {
_, err := w.Write(buf.Bytes())
return err
}
return nil
}
return nil
return marshalEachField(w)
case reflect.String:
if lenTag, ok := tag.Lookup("len_prefix"); ok {
switch lenTag {
case "uint8":
if err := binary.Write(w, binary.BigEndian, uint8(len(v.String()))); err != nil {
return err
}
case "uint16":
if err := binary.Write(w, binary.BigEndian, uint16(len(v.String()))); err != nil {
return err
}
default:
return fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrMarshalFailure, lenTag)
if err := writeUnsignedInt(lenTag, len(v.String()), w); err != nil {
return err
}
}
return binary.Write(w, binary.BigEndian, []byte(v.String()))
Expand All @@ -63,34 +73,16 @@ func marshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, w io.Writer
var hasLenPrefix bool
if l, ok := tag.Lookup("len_prefix"); ok {
hasLenPrefix = true
switch l {
case "uint8":
if err := binary.Write(w, binary.BigEndian, uint8(buf.Len())); err != nil {
return err
}
case "uint16":
if err := binary.Write(w, binary.BigEndian, uint16(buf.Len())); err != nil {
return err
}
default:
return fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrMarshalFailure, l)
if err := writeUnsignedInt(l, buf.Len(), w); err != nil {
return err
}
}
if l, ok := tag.Lookup("count_prefix"); ok {
if hasLenPrefix {
return fmt.Errorf("%w: struct elem has both len_prefix and count_prefix: ", ErrMarshalFailure)
}
switch l {
case "uint8":
if err := binary.Write(w, binary.BigEndian, uint8(v.Len())); err != nil {
return err
}
case "uint16":
if err := binary.Write(w, binary.BigEndian, uint16(v.Len())); err != nil {
return err
}
default:
return fmt.Errorf("%w: unsupported count_prefix type %s. allowed types: uint8, uint16", ErrMarshalFailure, l)
if err := writeUnsignedInt(l, v.Len(), w); err != nil {
return err
}
}
if buf.Len() > 0 {
Expand All @@ -104,3 +96,19 @@ func marshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, w io.Writer
return fmt.Errorf("%w: unsupported type %v", ErrMarshalFailure, t.Kind())
}
}

func writeUnsignedInt(intType string, intVal int, w io.Writer) error {
switch intType {
case "uint8":
if err := binary.Write(w, binary.BigEndian, uint8(intVal)); err != nil {
return err
}
case "uint16":
if err := binary.Write(w, binary.BigEndian, uint16(intVal)); err != nil {
return err
}
default:
return fmt.Errorf("%w: unsupported type %s. allowed types: uint8, uint16", ErrMarshalFailure, intType)
}
return nil
}
Loading

0 comments on commit c38ae1a

Please sign in to comment.