From c38ae1ad44a1cb99cbcce65a1fb26bec13760b13 Mon Sep 17 00:00:00 2001 From: Mike Date: Fri, 5 Jul 2024 14:40:19 -0400 Subject: [PATCH] wire: add struct length prefix struct tag Add a marshal/unmarshal struct tag that indicates whether to handle a length prefix for structs payloads. --- wire/decode.go | 92 ++++++++++++++++++-------------------- wire/decode_test.go | 93 +++++++++++++++++++++++++++++++++++++++ wire/encode.go | 80 ++++++++++++++++++--------------- wire/encode_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 285 insertions(+), 85 deletions(-) diff --git a/wire/decode.go b/wire/decode.go index 7b6a07ba..62eb4725 100644 --- a/wire/decode.go +++ b/wire/decode.go @@ -18,6 +18,19 @@ func Unmarshal(v any, r io.Reader) error { func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Reader) error { switch v.Kind() { case reflect.Struct: + if lenTag, ok := tag.Lookup("len_prefix"); ok { + bufLen, err := readUnsignedInt(lenTag, r) + if err != nil { + return err + } + b := make([]byte, bufLen) + if bufLen > 0 { + if _, err := io.ReadFull(r, b); err != nil { + return err + } + } + r = bytes.NewBuffer(b) + } for i := 0; i < v.NumField(); i++ { if err := unmarshal(t.Field(i).Type, v.Field(i), t.Field(i).Tag, r); err != nil { return err @@ -27,21 +40,9 @@ func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Read case reflect.String: var bufLen int if lenTag, ok := tag.Lookup("len_prefix"); ok { - switch lenTag { - case "uint8": - var l uint8 - if err := binary.Read(r, binary.BigEndian, &l); err != nil { - return err - } - bufLen = int(l) - case "uint16": - var l uint16 - if err := binary.Read(r, binary.BigEndian, &l); err != nil { - return err - } - bufLen = int(l) - default: - return fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrUnmarshalFailure, lenTag) + var err error + if bufLen, err = readUnsignedInt(lenTag, r); err != nil { + return err } } else { return fmt.Errorf("%w: missing len_prefix tag", ErrUnmarshalFailure) @@ -85,24 +86,10 @@ func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Read return nil case reflect.Slice: if lenTag, ok := tag.Lookup("len_prefix"); ok { - var bufLen int - switch lenTag { - case "uint8": - var l uint8 - if err := binary.Read(r, binary.BigEndian, &l); err != nil { - return err - } - bufLen = int(l) - case "uint16": - var l uint16 - if err := binary.Read(r, binary.BigEndian, &l); err != nil { - return err - } - bufLen = int(l) - default: - return fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrUnmarshalFailure, lenTag) + bufLen, err := readUnsignedInt(lenTag, r) + if err != nil { + return err } - buf := make([]byte, bufLen) if bufLen > 0 { if _, err := io.ReadFull(r, buf); err != nil { @@ -123,24 +110,10 @@ func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Read } v.Set(slice) } else if countTag, ok := tag.Lookup("count_prefix"); ok { - var count int - switch countTag { - case "uint8": - var l uint8 - if err := binary.Read(r, binary.BigEndian, &l); err != nil { - return err - } - count = int(l) - case "uint16": - var l uint16 - if err := binary.Read(r, binary.BigEndian, &l); err != nil { - return err - } - count = int(l) - default: - return fmt.Errorf("%w: unsupported count_prefix type %s. allowed types: uint8, uint16", ErrUnmarshalFailure, lenTag) + count, err := readUnsignedInt(countTag, r) + if err != nil { + return err } - slice := reflect.New(v.Type()).Elem() for i := 0; i < count; i++ { v1 := reflect.New(v.Type().Elem()).Interface() @@ -169,3 +142,24 @@ func unmarshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, r io.Read return fmt.Errorf("%w: unsupported type %v", ErrUnmarshalFailure, t.Kind()) } } + +func readUnsignedInt(intType string, r io.Reader) (int, error) { + var bufLen int + switch intType { + case "uint8": + var l uint8 + if err := binary.Read(r, binary.BigEndian, &l); err != nil { + return 0, err + } + bufLen = int(l) + case "uint16": + var l uint16 + if err := binary.Read(r, binary.BigEndian, &l); err != nil { + return 0, err + } + bufLen = int(l) + default: + return 0, fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrUnmarshalFailure, intType) + } + return bufLen, nil +} diff --git a/wire/decode_test.go b/wire/decode_test.go index 26fa3d62..f742258d 100644 --- a/wire/decode_test.go +++ b/wire/decode_test.go @@ -349,6 +349,99 @@ func TestUnmarshal(t *testing.T) { []byte{0x0, 0x02}, /* count prefix */ []byte{0x0, 0xa, 0x0, 0x2, 0x4, 0xd2, 0x0, 0x14, 0x0, 0x2, 0x4, 0xd2}...), /* slice val */ }, + { + name: "struct with uint8 len_prefix", + prototype: &struct { + Val0 uint8 + Val1 struct { + Val2 uint16 + Val3 uint8 + } `len_prefix:"uint8"` + Val4 uint16 + }{}, + want: &struct { + Val0 uint8 + Val1 struct { + Val2 uint16 + Val3 uint8 + } `len_prefix:"uint8"` + Val4 uint16 + }{ + Val0: 34, + Val1: struct { + Val2 uint16 + Val3 uint8 + }{ + Val2: 16, + Val3: 10, + }, + Val4: 32, + }, + given: []byte{ + 0x22, // Val0 + 0x03, // Val1 struct len + 0x00, 0x10, // Val2 + 0x0A, // Val3 + 0x00, 0x20, // Val2 + }, + }, + { + name: "struct with uint16 len_prefix", + prototype: &struct { + Val0 uint8 + Val1 struct { + Val2 uint16 + Val3 uint8 + } `len_prefix:"uint16"` + Val4 uint16 + }{}, + want: &struct { + Val0 uint8 + Val1 struct { + Val2 uint16 + Val3 uint8 + } `len_prefix:"uint16"` + Val4 uint16 + }{ + Val0: 34, + Val1: struct { + Val2 uint16 + Val3 uint8 + }{ + Val2: 16, + Val3: 10, + }, + Val4: 32, + }, + given: []byte{ + 0x22, // Val0 + 0x00, 0x03, // Val1 struct len + 0x00, 0x10, // Val2 + 0x0A, // Val3 + 0x00, 0x20, // Val2 + }, + }, + { + name: "struct with uint16 len_prefix with read error", + prototype: &struct { + Val1 struct { + Val2 uint16 + } `len_prefix:"uint16"` + }{}, + given: []byte{ + 0x00, 0x10, // 16 byte len, but the body is truncated + }, + wantErr: io.EOF, + }, + { + name: "struct with unknown len_prefix", + prototype: &struct { + Val1 struct { + Val2 uint16 + } `len_prefix:"uint128"` + }{}, + wantErr: ErrUnmarshalFailure, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/wire/encode.go b/wire/encode.go index 4b7e1010..51b0f127 100644 --- a/wire/encode.go +++ b/wire/encode.go @@ -22,25 +22,35 @@ func marshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, w io.Writer } switch t.Kind() { case reflect.Struct: - for i := 0; i < t.NumField(); i++ { - if err := marshal(t.Field(i).Type, v.Field(i), t.Field(i).Tag, w); err != nil { + marshalEachField := func(w io.Writer) error { + for i := 0; i < t.NumField(); i++ { + if err := marshal(t.Field(i).Type, v.Field(i), t.Field(i).Tag, w); err != nil { + return err + } + } + return nil + } + if lenTag, ok := tag.Lookup("len_prefix"); ok { + buf := &bytes.Buffer{} + if err := marshalEachField(buf); err != nil { + return err + } + // write struct length + if err := writeUnsignedInt(lenTag, buf.Len(), w); err != nil { + return err + } + // write struct bytes + if buf.Len() > 0 { + _, err := w.Write(buf.Bytes()) return err } + return nil } - return nil + return marshalEachField(w) case reflect.String: if lenTag, ok := tag.Lookup("len_prefix"); ok { - switch lenTag { - case "uint8": - if err := binary.Write(w, binary.BigEndian, uint8(len(v.String()))); err != nil { - return err - } - case "uint16": - if err := binary.Write(w, binary.BigEndian, uint16(len(v.String()))); err != nil { - return err - } - default: - return fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrMarshalFailure, lenTag) + if err := writeUnsignedInt(lenTag, len(v.String()), w); err != nil { + return err } } return binary.Write(w, binary.BigEndian, []byte(v.String())) @@ -63,34 +73,16 @@ func marshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, w io.Writer var hasLenPrefix bool if l, ok := tag.Lookup("len_prefix"); ok { hasLenPrefix = true - switch l { - case "uint8": - if err := binary.Write(w, binary.BigEndian, uint8(buf.Len())); err != nil { - return err - } - case "uint16": - if err := binary.Write(w, binary.BigEndian, uint16(buf.Len())); err != nil { - return err - } - default: - return fmt.Errorf("%w: unsupported len_prefix type %s. allowed types: uint8, uint16", ErrMarshalFailure, l) + if err := writeUnsignedInt(l, buf.Len(), w); err != nil { + return err } } if l, ok := tag.Lookup("count_prefix"); ok { if hasLenPrefix { return fmt.Errorf("%w: struct elem has both len_prefix and count_prefix: ", ErrMarshalFailure) } - switch l { - case "uint8": - if err := binary.Write(w, binary.BigEndian, uint8(v.Len())); err != nil { - return err - } - case "uint16": - if err := binary.Write(w, binary.BigEndian, uint16(v.Len())); err != nil { - return err - } - default: - return fmt.Errorf("%w: unsupported count_prefix type %s. allowed types: uint8, uint16", ErrMarshalFailure, l) + if err := writeUnsignedInt(l, v.Len(), w); err != nil { + return err } } if buf.Len() > 0 { @@ -104,3 +96,19 @@ func marshal(t reflect.Type, v reflect.Value, tag reflect.StructTag, w io.Writer return fmt.Errorf("%w: unsupported type %v", ErrMarshalFailure, t.Kind()) } } + +func writeUnsignedInt(intType string, intVal int, w io.Writer) error { + switch intType { + case "uint8": + if err := binary.Write(w, binary.BigEndian, uint8(intVal)); err != nil { + return err + } + case "uint16": + if err := binary.Write(w, binary.BigEndian, uint16(intVal)); err != nil { + return err + } + default: + return fmt.Errorf("%w: unsupported type %s. allowed types: uint8, uint16", ErrMarshalFailure, intType) + } + return nil +} diff --git a/wire/encode_test.go b/wire/encode_test.go index e34bfcc2..76ff47f9 100644 --- a/wire/encode_test.go +++ b/wire/encode_test.go @@ -329,6 +329,111 @@ func TestMarshal(t *testing.T) { given: nil, wantErr: ErrMarshalFailureNilSNAC, }, + { + name: "struct with uint8 len_prefix", + w: &bytes.Buffer{}, + given: struct { + Val0 uint8 + Val1 struct { + Val2 uint16 + Val3 uint8 + } `len_prefix:"uint8"` + Val4 uint16 + }{ + Val0: 34, + Val1: struct { + Val2 uint16 + Val3 uint8 + }{ + Val2: 16, + Val3: 10, + }, + Val4: 32, + }, + want: append( + []byte{ + 0x22, // Val0 + 0x03, // Val1 struct len + 0x00, 0x10, // Val2 + 0x0A, // Val3 + 0x00, 0x20, // Val2 + }), + }, + { + name: "struct with uint16 len_prefix", + w: &bytes.Buffer{}, + given: struct { + Val0 uint8 + Val1 struct { + Val2 uint16 + Val3 uint8 + } `len_prefix:"uint16"` + Val4 uint16 + }{ + Val0: 34, + Val1: struct { + Val2 uint16 + Val3 uint8 + }{ + Val2: 16, + Val3: 10, + }, + Val4: 32, + }, + want: []byte{ + 0x22, // Val0 + 0x00, 0x03, // Val1 struct len + 0x00, 0x10, // Val2 + 0x0A, // Val3 + 0x00, 0x20, // Val2 + }, + }, + { + name: "invalid struct with uint16 len_prefix", + w: &bytes.Buffer{}, + given: struct { + Val1 struct { + Val2 int + } `len_prefix:"uint16"` + }{ + Val1: struct { + Val2 int + }{ + Val2: 16, + }, + }, + wantErr: ErrMarshalFailure, + }, + { + name: "empty struct with uint16 len_prefix", + w: &bytes.Buffer{}, + given: struct { + Val1 struct { + } `len_prefix:"uint16"` + }{ + Val1: struct { + }{}, + }, + want: []byte{ + 0x00, 0x00, // 0-len + }, + }, + { + name: "struct with unknown len_prefix", + w: &bytes.Buffer{}, + given: struct { + Val1 struct { + Val2 uint16 + } `len_prefix:"uint128"` + }{ + Val1: struct { + Val2 uint16 + }{ + Val2: 16, + }, + }, + wantErr: ErrMarshalFailure, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {