Skip to content

Commit 06a02f7

Browse files
committed
multi: enforce strict TLV length checks
This commit improves TLV decoding safety and consistency across multiple packages by enforcing fixed-length requirements and adding unit tests to prevent malformed TLV records from being accepted. Changes include: - lnwire: * Enforce 8-byte length in Fee TLV decoder. * Enforce PubNonceSize in Musig2Nonce TLV decoder. * Enforce 8-byte length in ShortChannelID TLV decoder. * Added roundtrip and invalid length tests for Fee, Musig2Nonce, and ShortChannelID records. - routing/route: * Enforce Vertex TLV length (33 bytes). * Added encode/decode and invalid length tests for Vertex. - tlv: * Enforce correct length in DBytes33 decoder (33 bytes). * Added tests ensuring all fixed-size primitive decoders reject incorrect TLV lengths. By strictly validating TLV lengths, we prevent malformed or corrupted TLV records from being silently accepted, improving protocol safety.
1 parent 0c2f045 commit 06a02f7

File tree

10 files changed

+224
-5
lines changed

10 files changed

+224
-5
lines changed

lnwire/musig2.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func nonceTypeEncoder(w io.Writer, val interface{}, _ *[8]byte) error {
5151
func nonceTypeDecoder(r io.Reader, val interface{}, _ *[8]byte,
5252
l uint64) error {
5353

54-
if v, ok := val.(*Musig2Nonce); ok {
54+
if v, ok := val.(*Musig2Nonce); ok && l == musig2.PubNonceSize {
5555
_, err := io.ReadFull(r, v[:])
5656
return err
5757
}

lnwire/musig2_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package lnwire
2+
3+
import (
4+
"testing"
5+
6+
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func makeNonce() Musig2Nonce {
11+
var n Musig2Nonce
12+
for i := range musig2.PubNonceSize {
13+
n[i] = byte(i)
14+
}
15+
16+
return n
17+
}
18+
19+
// TestMusig2NonceEncodeDecode tests that we're able to properly encode and
20+
// decode Musig2Nonce within TLV streams.
21+
func TestMusig2NonceEncodeDecode(t *testing.T) {
22+
t.Parallel()
23+
24+
nonce := makeNonce()
25+
26+
var extraData ExtraOpaqueData
27+
require.NoError(t, extraData.PackRecords(&nonce))
28+
29+
var extractedNonce Musig2Nonce
30+
_, err := extraData.ExtractRecords(&extractedNonce)
31+
require.NoError(t, err)
32+
33+
require.Equal(t, nonce, extractedNonce)
34+
}
35+
36+
// TestMusig2NonceTypeDecodeInvalidLength ensures that decoding a Musig2Nonce
37+
// TLV with an invalid length (anything other than 66 bytes) fails with an
38+
// error.
39+
func TestMusig2NonceTypeDecodeInvalidLength(t *testing.T) {
40+
t.Parallel()
41+
42+
nonce := makeNonce()
43+
44+
var extraData ExtraOpaqueData
45+
require.NoError(t, extraData.PackRecords(&nonce))
46+
47+
// Corrupt the TLV length field to simulate malformed input.
48+
extraData[1] = musig2.PubNonceSize + 1
49+
50+
var out Musig2Nonce
51+
_, err := extraData.ExtractRecords(&out)
52+
require.Error(t, err)
53+
54+
extraData[1] = musig2.PubNonceSize - 1
55+
56+
_, err = extraData.ExtractRecords(&out)
57+
require.Error(t, err)
58+
}

lnwire/short_channel_id.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ func DShortChannelID(r io.Reader, val interface{}, buf *[8]byte,
9292

9393
if v, ok := val.(*ShortChannelID); ok {
9494
var scid uint64
95-
err := tlv.DUint64(r, &scid, buf, 8)
95+
// tlv.DUint64 forces the length to be 8 bytes.
96+
err := tlv.DUint64(r, &scid, buf, l)
9697
if err != nil {
9798
return err
9899
}

lnwire/short_channel_id_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,28 @@ func TestScidTypeEncodeDecode(t *testing.T) {
6262
require.Contains(t, tlvs, AliasScidRecordType)
6363
require.Equal(t, aliasScid, aliasScid2)
6464
}
65+
66+
// TestScidTypeDecodeInvalidLength ensures that decoding a ShortChannelID TLV
67+
// with an invalid length (anything other than 8 bytes) fails with an error.
68+
func TestScidTypeDecodeInvalidLength(t *testing.T) {
69+
t.Parallel()
70+
71+
aliasScid := ShortChannelID{
72+
BlockHeight: 1, TxIndex: 1, TxPosition: 1,
73+
}
74+
75+
var extraData ExtraOpaqueData
76+
require.NoError(t, extraData.PackRecords(&aliasScid))
77+
78+
// Corrupt the TLV length field to simulate malformed input.
79+
extraData[1] = 8 + 1
80+
81+
var out ShortChannelID
82+
_, err := extraData.ExtractRecords(&out)
83+
require.Error(t, err)
84+
85+
extraData[1] = 8 - 1
86+
87+
_, err = extraData.ExtractRecords(&out)
88+
require.Error(t, err)
89+
}

lnwire/typed_fee.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func feeEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
4141
// feeDecoder is a custom TLV decoder for the fee record.
4242
func feeDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
4343
v, ok := val.(*Fee)
44-
if !ok {
44+
if !ok || l != 8 {
4545
return tlv.NewTypeForDecodingErr(val, "lnwire.Fee", l, 8)
4646
}
4747

lnwire/typed_fee_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,28 @@ func testTypedFee(t *testing.T, fee Fee) { //nolint: thelper
3838

3939
require.Equal(t, fee, extractedFee)
4040
}
41+
42+
// TestTypedFeeTypeDecodeInvalidLength ensures that decoding a Fee TLV
43+
// with an invalid length (anything other than 8 bytes) fails with an error.
44+
func TestTypedFeeTypeDecodeInvalidLength(t *testing.T) {
45+
t.Parallel()
46+
47+
fee := Fee{
48+
BaseFee: 1, FeeRate: 1,
49+
}
50+
51+
var extraData ExtraOpaqueData
52+
require.NoError(t, extraData.PackRecords(&fee))
53+
54+
// Corrupt the TLV length field to simulate malformed input.
55+
extraData[3] = 8 + 1
56+
57+
var out Fee
58+
_, err := extraData.ExtractRecords(&out)
59+
require.Error(t, err)
60+
61+
extraData[3] = 8 - 1
62+
63+
_, err = extraData.ExtractRecords(&out)
64+
require.Error(t, err)
65+
}

routing/route/route.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error {
112112
}
113113

114114
func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
115-
if b, ok := val.(*Vertex); ok {
115+
if b, ok := val.(*Vertex); ok && l == VertexSize {
116116
_, err := io.ReadFull(r, b[:])
117117
return err
118118
}

routing/route/route_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/btcsuite/btcd/btcec/v2"
99
"github.com/lightningnetwork/lnd/lnwire"
1010
"github.com/lightningnetwork/lnd/record"
11+
"github.com/lightningnetwork/lnd/tlv"
1112
"github.com/stretchr/testify/require"
1213
)
1314

@@ -430,3 +431,53 @@ func TestBlindedHopFee(t *testing.T) {
430431
require.Equal(t, lnwire.MilliSatoshi(0), route.HopFee(3))
431432
require.Equal(t, lnwire.MilliSatoshi(0), route.HopFee(4))
432433
}
434+
435+
func makeVertex() Vertex {
436+
var v Vertex
437+
for i := range VertexSize {
438+
v[i] = byte(i)
439+
}
440+
441+
return v
442+
}
443+
444+
// TestVertexTLVEncodeDecode tests that we're able to properly encode and decode
445+
// Vertex within TLV streams.
446+
func TestVertexTLVEncodeDecode(t *testing.T) {
447+
t.Parallel()
448+
449+
vertex := makeVertex()
450+
451+
var extraData lnwire.ExtraOpaqueData
452+
require.NoError(t, extraData.PackRecords(&vertex))
453+
454+
var vertex2 Vertex
455+
tlvs, err := extraData.ExtractRecords(&vertex2)
456+
require.NoError(t, err)
457+
458+
require.Contains(t, tlvs, tlv.Type(0))
459+
require.Equal(t, vertex, vertex2)
460+
}
461+
462+
// TestVertexTypeDecodeInvalidLength ensures that decoding a Vertex TLV
463+
// with an invalid length (anything other than 33) fails with an error.
464+
func TestVertexTypeDecodeInvalidLength(t *testing.T) {
465+
t.Parallel()
466+
467+
vertex := makeVertex()
468+
469+
var extraData lnwire.ExtraOpaqueData
470+
require.NoError(t, extraData.PackRecords(&vertex))
471+
472+
// Corrupt the TLV length field to simulate malformed input.
473+
extraData[1] = VertexSize + 1
474+
475+
var out Vertex
476+
_, err := extraData.ExtractRecords(&out)
477+
require.Error(t, err)
478+
479+
extraData[1] = VertexSize - 1
480+
481+
_, err = extraData.ExtractRecords(&out)
482+
require.Error(t, err)
483+
}

tlv/primitive.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ func EBytes33(w io.Writer, val interface{}, _ *[8]byte) error {
257257
// DBytes33 is a Decoder for 33-byte arrays. An error is returned if val is not
258258
// a *[33]byte.
259259
func DBytes33(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
260-
if b, ok := val.(*[33]byte); ok {
260+
if b, ok := val.(*[33]byte); ok && l == 33 {
261261
_, err := io.ReadFull(r, b[:])
262262
return err
263263
}

tlv/primitive_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,62 @@ func TestPrimitiveEncodings(t *testing.T) {
253253
prim, prim2)
254254
}
255255
}
256+
257+
// TestPrimitiveWrongLength asserts that fixed-size primitive decoders fail
258+
// with ErrTypeForDecoding when given an incorrect TLV length.
259+
func TestPrimitiveWrongLength(t *testing.T) {
260+
prim := primitive{
261+
u8: 0x01,
262+
u16: 0x0201,
263+
u32: 0x02000001,
264+
u64: 0x0200000000000001,
265+
b32: [32]byte{0x02, 0x01},
266+
b33: [33]byte{0x03, 0x01},
267+
b64: [64]byte{0x02, 0x01},
268+
pk: testPK,
269+
boolean: true,
270+
}
271+
272+
type item struct {
273+
enc fieldEncoder
274+
dec fieldDecoder
275+
}
276+
277+
items := []item{
278+
{fieldEncoder{&prim.u8, tlv.EUint8}, fieldDecoder{new(byte), tlv.DUint8, 1}},
279+
{fieldEncoder{&prim.u16, tlv.EUint16}, fieldDecoder{new(uint16), tlv.DUint16, 2}},
280+
{fieldEncoder{&prim.u32, tlv.EUint32}, fieldDecoder{new(uint32), tlv.DUint32, 4}},
281+
{fieldEncoder{&prim.u64, tlv.EUint64}, fieldDecoder{new(uint64), tlv.DUint64, 8}},
282+
{fieldEncoder{&prim.b32, tlv.EBytes32}, fieldDecoder{new([32]byte), tlv.DBytes32, 32}},
283+
{fieldEncoder{&prim.b33, tlv.EBytes33}, fieldDecoder{new([33]byte), tlv.DBytes33, 33}},
284+
{fieldEncoder{&prim.b64, tlv.EBytes64}, fieldDecoder{new([64]byte), tlv.DBytes64, 64}},
285+
{fieldEncoder{&prim.pk, tlv.EPubKey}, fieldDecoder{new(*btcec.PublicKey), tlv.DPubKey, 33}},
286+
{fieldEncoder{&prim.boolean, tlv.EBool}, fieldDecoder{new(bool), tlv.DBool, 1}},
287+
}
288+
289+
for _, it := range items {
290+
var buf [8]byte
291+
var b bytes.Buffer
292+
if err := it.enc.encoder(&b, it.enc.val, &buf); err != nil {
293+
t.Fatalf("encode %T: %v", it.enc.val, err)
294+
}
295+
data := b.Bytes()
296+
297+
// Generate two wrong lengths: expected-1 (if >0) and expected+1.
298+
wrongs := []uint64{it.dec.size + 1}
299+
if it.dec.size > 0 {
300+
wrongs = append(wrongs, it.dec.size-1)
301+
}
302+
303+
for _, l := range wrongs {
304+
r := bytes.NewReader(data)
305+
if err := it.dec.decoder(r, it.dec.val, &buf, l); err == nil {
306+
t.Fatalf("decoder %T accepted wrong length %d (expected %d)", it.dec.decoder, l, it.dec.size)
307+
} else {
308+
if _, ok := err.(tlv.ErrTypeForDecoding); !ok {
309+
t.Fatalf("expected ErrTypeForDecoding, got %T: %v", err, err)
310+
}
311+
}
312+
}
313+
}
314+
}

0 commit comments

Comments
 (0)