Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/signer/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ import "errors"
var (
// ErrInvalidPrivateKey is returned when a private key cannot be decoded or parsed.
ErrInvalidPrivateKey = errors.New("invalid private key")

// ErrInvalidSignature is returned when a signature has invalid components.
ErrInvalidSignature = errors.New("invalid signature")
)
20 changes: 20 additions & 0 deletions pkg/signer/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,28 @@ func (s *Signer) SignData(data []byte) (*Signature, error) {
return s.Sign(hash)
}

// maxScalarBytes is the maximum byte length for secp256k1 scalar values (R and S).
// Scalars must fit within 32 bytes (256 bits) to be valid signature components.
const maxScalarBytes = 32

// RecoverAddress recovers the address that signed the given hash with the given signature.
func RecoverAddress(hash common.Hash, sig *Signature) (common.Address, error) {
if sig == nil {
return common.Address{}, fmt.Errorf("%w: signature is nil", ErrInvalidSignature)
}
if sig.R == nil || sig.S == nil {
return common.Address{}, fmt.Errorf("%w: R or S is nil", ErrInvalidSignature)
}

rBytes := sig.R.Bytes()
if len(rBytes) > maxScalarBytes {
return common.Address{}, fmt.Errorf("%w: R exceeds %d bytes (got %d)", ErrInvalidSignature, maxScalarBytes, len(rBytes))
}
sBytes := sig.S.Bytes()
if len(sBytes) > maxScalarBytes {
return common.Address{}, fmt.Errorf("%w: S exceeds %d bytes (got %d)", ErrInvalidSignature, maxScalarBytes, len(sBytes))
}

sigBytes := make([]byte, 65)
sig.R.FillBytes(sigBytes[0:32])
sig.S.FillBytes(sigBytes[32:64])
Expand Down
70 changes: 70 additions & 0 deletions pkg/signer/signer_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package signer

import (
"errors"
"math/big"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -114,3 +117,70 @@ func TestRecoverAddress(t *testing.T) {

assert.Equal(t, sgn.Address(), recoveredAddr)
}

func TestRecoverAddress_InvalidSignatures(t *testing.T) {
big33Bytes := new(big.Int).Lsh(big.NewInt(1), 256) // 2^256 requires 33 bytes

tests := []struct {
name string
sig *Signature
wantErr bool
wantErrStr string
}{
{
name: "nil signature",
sig: nil,
wantErr: true,
wantErrStr: "signature is nil",
},
{
name: "nil R",
sig: &Signature{R: nil, S: big.NewInt(1), YParity: 0},
wantErr: true,
wantErrStr: "nil",
},
{
name: "nil S",
sig: &Signature{R: big.NewInt(1), S: nil, YParity: 0},
wantErr: true,
wantErrStr: "nil",
},
{
name: "nil R and S",
sig: &Signature{R: nil, S: nil, YParity: 0},
wantErr: true,
wantErrStr: "nil",
},
{
name: "oversized R (33 bytes)",
sig: &Signature{R: big33Bytes, S: big.NewInt(1), YParity: 0},
wantErr: true,
wantErrStr: "R exceeds",
},
{
name: "oversized S (33 bytes)",
sig: &Signature{R: big.NewInt(1), S: big33Bytes, YParity: 0},
wantErr: true,
wantErrStr: "S exceeds",
},
{
name: "oversized R and S",
sig: &Signature{R: big33Bytes, S: big33Bytes, YParity: 0},
wantErr: true,
wantErrStr: "R exceeds",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := RecoverAddress(common.Hash{}, tt.sig)
if tt.wantErr {
assert.Error(t, err)
assert.True(t, errors.Is(err, ErrInvalidSignature))
assert.Contains(t, err.Error(), tt.wantErrStr)
} else {
assert.NoError(t, err)
}
})
}
}
13 changes: 13 additions & 0 deletions pkg/transaction/deserialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ func decodeAccessList(accessListRaw []interface{}) (AccessList, error) {
return accessList, nil
}

// maxSignatureScalarBytes is the maximum byte length for secp256k1 signature scalars (R and S).
// Valid signature components must fit within 32 bytes (256 bits).
const maxSignatureScalarBytes = 32

// decodeSignature decodes a signature tuple [yParity, r, s].
func decodeSignature(sigTuple []interface{}) (*signer.Signature, error) {
if len(sigTuple) != 3 {
Expand All @@ -294,13 +298,22 @@ func decodeSignature(sigTuple []interface{}) (*signer.Signature, error) {
if !ok {
return nil, fmt.Errorf("r is not bytes")
}
// Validate R size to prevent DoS via oversized signature components.
// Oversized values would cause a panic in RecoverAddress when using FillBytes.
if len(rBytes) > maxSignatureScalarBytes {
return nil, fmt.Errorf("r exceeds maximum size: got %d bytes, max %d", len(rBytes), maxSignatureScalarBytes)
}
r := new(big.Int).SetBytes(rBytes)

// Field 2: s
sBytes, ok := sigTuple[2].([]byte)
if !ok {
return nil, fmt.Errorf("s is not bytes")
}
// Validate S size to prevent DoS via oversized signature components.
if len(sBytes) > maxSignatureScalarBytes {
return nil, fmt.Errorf("s exceeds maximum size: got %d bytes, max %d", len(sBytes), maxSignatureScalarBytes)
}
s := new(big.Int).SetBytes(sBytes)

return signer.NewSignature(r, s, yParity), nil
Expand Down
164 changes: 139 additions & 25 deletions pkg/transaction/serialization_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package transaction

import (
"encoding/hex"
"math/big"
"strings"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/tempoxyz/tempo-go/pkg/signer"
Expand Down Expand Up @@ -331,45 +333,84 @@ func TestDecodeAccessList(t *testing.T) {
}

func TestDecodeSignature(t *testing.T) {
big33Bytes := new(big.Int).Lsh(big.NewInt(1), 256)
max32Bytes := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 256), big.NewInt(1))

tests := []struct {
name string
input []interface{}
want *signer.Signature
wantErr bool
name string
r []byte
s []byte
yParity byte
want *signer.Signature
wantErr bool
wantErrStr string
}{
{
name: "valid signature",
input: []interface{}{
[]byte{0},
big.NewInt(12345).Bytes(),
big.NewInt(67890).Bytes(),
},
want: signer.NewSignature(big.NewInt(12345), big.NewInt(67890), 0),
name: "valid signature",
r: big.NewInt(12345).Bytes(),
s: big.NewInt(67890).Bytes(),
yParity: 0,
want: signer.NewSignature(big.NewInt(12345), big.NewInt(67890), 0),
},
{
name: "signature with yParity = 1",
input: []interface{}{
[]byte{1},
big.NewInt(12345).Bytes(),
big.NewInt(67890).Bytes(),
},
want: signer.NewSignature(big.NewInt(12345), big.NewInt(67890), 1),
name: "signature with yParity = 1",
r: big.NewInt(12345).Bytes(),
s: big.NewInt(67890).Bytes(),
yParity: 1,
want: signer.NewSignature(big.NewInt(12345), big.NewInt(67890), 1),
},
{
name: "invalid - wrong length",
input: []interface{}{
[]byte{0},
big.NewInt(12345).Bytes(),
},
wantErr: true,
name: "empty R and S",
r: []byte{},
s: []byte{},
yParity: 0,
want: signer.NewSignature(big.NewInt(0), big.NewInt(0), 0),
},
{
name: "max valid size (32 bytes)",
r: max32Bytes.Bytes(),
s: max32Bytes.Bytes(),
yParity: 0,
want: signer.NewSignature(max32Bytes, max32Bytes, 0),
},
{
name: "oversized R (33 bytes)",
r: big33Bytes.Bytes(),
s: big.NewInt(1).Bytes(),
yParity: 0,
wantErr: true,
wantErrStr: "r exceeds maximum size",
},
{
name: "oversized S (33 bytes)",
r: big.NewInt(1).Bytes(),
s: big33Bytes.Bytes(),
yParity: 0,
wantErr: true,
wantErrStr: "s exceeds maximum size",
},
{
name: "oversized R and S",
r: big33Bytes.Bytes(),
s: big33Bytes.Bytes(),
yParity: 0,
wantErr: true,
wantErrStr: "r exceeds maximum size",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := decodeSignature(tt.input)
input := []interface{}{
[]byte{tt.yParity},
tt.r,
tt.s,
}
got, err := decodeSignature(input)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
assert.Contains(t, err.Error(), tt.wantErrStr)
return
}
assert.NoError(t, err)
Expand All @@ -380,6 +421,16 @@ func TestDecodeSignature(t *testing.T) {
}
}

func TestDecodeSignature_InvalidTupleLength(t *testing.T) {
input := []interface{}{
[]byte{0},
big.NewInt(12345).Bytes(),
}
got, err := decodeSignature(input)
assert.Error(t, err)
assert.Nil(t, got)
}

func TestRoundtrip(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -568,6 +619,69 @@ func TestRoundtripWithOptions(t *testing.T) {
})
}

func TestDeserialize_OversizedSignature(t *testing.T) {
big33Bytes := new(big.Int).Lsh(big.NewInt(1), 256) // 33 bytes

tests := []struct {
name string
r []byte
s []byte
wantErrStr string
}{
{
name: "oversized R in fee payer signature",
r: big33Bytes.Bytes(),
s: big.NewInt(1).Bytes(),
wantErrStr: "r exceeds maximum size",
},
{
name: "oversized S in fee payer signature",
r: big.NewInt(1).Bytes(),
s: big33Bytes.Bytes(),
wantErrStr: "s exceeds maximum size",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rlpList := []interface{}{
big.NewInt(42424).Bytes(), // chainId
big.NewInt(1000000).Bytes(), // maxPriorityFeePerGas
big.NewInt(2000000).Bytes(), // maxFeePerGas
big.NewInt(21000).Bytes(), // gas
[]interface{}{ // calls
[]interface{}{
common.HexToAddress("0x1234567890123456789012345678901234567890").Bytes(),
big.NewInt(1000000).Bytes(),
[]byte{},
},
},
[]interface{}{}, // accessList
[]byte{}, // nonceKey
big.NewInt(1).Bytes(), // nonce
[]byte{}, // validBefore
[]byte{}, // validAfter
[]byte{}, // feeToken
[]interface{}{ // fee payer signature
[]byte{0},
tt.r,
tt.s,
},
[]interface{}{}, // authorizationList
}

rlpBytes, err := rlp.EncodeToBytes(rlpList)
assert.NoError(t, err)

serialized := "0x76" + hex.EncodeToString(rlpBytes)

_, err = Deserialize(serialized)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErrStr)
})
}
}

// Helper functions

func hexToBigInt(s string) *big.Int {
Expand Down
Loading