diff --git a/core/trie/key.go b/core/trie/key.go index 208f9fcd89..cd254a2387 100644 --- a/core/trie/key.go +++ b/core/trie/key.go @@ -1,10 +1,9 @@ package trie import ( - "bytes" "encoding/hex" - "errors" "fmt" + "io" "math/big" "github.com/NethermindEth/juno/core/felt" @@ -61,7 +60,7 @@ func (k *Key) MostSignificantBits(n uint8) (*Key, error) { func (k *Key) SubKey(n uint8) (*Key, error) { if n > k.len { - return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len)) + return nil, fmt.Errorf("cannot subtract key of length %d from key of length %d", n, k.len) } if n == k.len { return &Key{}, nil @@ -95,8 +94,8 @@ func (k *Key) unusedBytes() []byte { return k.bitset[:len(k.bitset)-int(k.bytesNeeded())] } -func (k *Key) WriteTo(buf *bytes.Buffer) (int64, error) { - if err := buf.WriteByte(k.len); err != nil { +func (k *Key) WriteTo(buf io.Writer) (int64, error) { + if _, err := buf.Write([]byte{k.len}); err != nil { return 0, err } diff --git a/core/trie/key_test.go b/core/trie/key_test.go index 32b08b9e06..f3a7dc0842 100644 --- a/core/trie/key_test.go +++ b/core/trie/key_test.go @@ -2,6 +2,7 @@ package trie_test import ( "bytes" + "errors" "testing" "github.com/NethermindEth/juno/core/felt" @@ -153,3 +154,60 @@ func TestTruncate(t *testing.T) { }) } } + +func TestKeyErrorHandling(t *testing.T) { + t.Run("passed too long key bytes panics", func(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + require.Contains(t, r.(string), "bytes does not fit in bitset") + }() + tooLongKeyB := make([]byte, 33) + trie.NewKey(8, tooLongKeyB) + }) + t.Run("MostSignificantBits n greater than key length", func(t *testing.T) { + key := trie.NewKey(8, []byte{0x01}) + _, err := key.MostSignificantBits(9) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot take 9 bits from key of length 8") + }) + t.Run("MostSignificantBits equals key length return copy of key", func(t *testing.T) { + key := trie.NewKey(8, []byte{0x01}) + kCopy, err := key.MostSignificantBits(8) + require.NoError(t, err) + require.Equal(t, key, *kCopy) + }) + t.Run("SubKey n greater than key length", func(t *testing.T) { + key := trie.NewKey(8, []byte{0x01}) + _, err := key.SubKey(9) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot subtract key of length 9 from key of length 8") + }) + t.Run("SubKey n equals k length returns empty key", func(t *testing.T) { + key := trie.NewKey(8, []byte{0x01}) + kCopy, err := key.SubKey(8) + require.NoError(t, err) + require.Equal(t, trie.Key{}, *kCopy) + }) + t.Run("delete more bits than key length panics", func(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + require.Contains(t, r.(string), "deleting more bits than there are") + }() + key := trie.NewKey(8, []byte{0x01}) + key.DeleteLSB(9) + }) + t.Run("WriteTo returns error", func(t *testing.T) { + key := trie.NewKey(8, []byte{0x01}) + wrote, err := key.WriteTo(&errorBuffer{}) + require.Error(t, err) + require.Equal(t, int64(0), wrote) + }) +} + +type errorBuffer struct{} + +func (*errorBuffer) Write([]byte) (int, error) { + return 0, errors.New("expected to fail") +} diff --git a/core/trie/node.go b/core/trie/node.go index c56dde3603..1e172e256a 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -1,8 +1,8 @@ package trie import ( - "bytes" "errors" + "io" "github.com/NethermindEth/juno/core/felt" ) @@ -38,7 +38,7 @@ func (n *Node) HashFromParent(parnetKey, nodeKey *Key, hashFunc HashFunc) *felt. return n.Hash(&path, hashFunc) } -func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { +func (n *Node) WriteTo(buf io.Writer) (int64, error) { if n.Value == nil { return 0, errors.New("cannot marshal node with nil value") } diff --git a/core/trie/node_test.go b/core/trie/node_test.go index ccb52b3eac..3ac71a9241 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -1,7 +1,9 @@ package trie_test import ( + "bytes" "encoding/hex" + "errors" "testing" "github.com/NethermindEth/juno/core/crypto" @@ -26,3 +28,33 @@ func TestNodeHash(t *testing.T) { assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") } + +func TestNodeErrorHandling(t *testing.T) { + t.Run("WriteTo node value is nil", func(t *testing.T) { + node := trie.Node{} + var buffer bytes.Buffer + _, err := node.WriteTo(&buffer) + require.Error(t, err) + }) + t.Run("WriteTo returns error", func(t *testing.T) { + node := trie.Node{ + Value: new(felt.Felt).SetUint64(42), + Left: &trie.Key{}, + Right: &trie.Key{}, + } + + wrote, err := node.WriteTo(&errorBuffer{}) + require.Error(t, err) + require.Equal(t, int64(0), wrote) + }) + t.Run("UnmarshalBinary returns error", func(t *testing.T) { + node := trie.Node{} + + err := node.UnmarshalBinary([]byte{42}) + require.Equal(t, errors.New("size of input data is less than felt size"), err) + + bs := new(felt.Felt).Bytes() + err = node.UnmarshalBinary(append(bs[:], 0, 0, 42)) + require.Equal(t, errors.New("the node does not contain both left and right hash"), err) + }) +} diff --git a/core/trie/proof.go b/core/trie/proof.go index ea32a02007..2bccef9a77 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -274,94 +274,6 @@ func VerifyProof(root *felt.Felt, key *Key, proofSet *ProofSet, hash HashFunc) ( } } -// VerifyRangeProof verifies the range proof for the given range of keys. -// This is achieved by constructing a trie from the boundary proofs, and the supplied key-values. -// If the root of the reconstructed trie matches the supplied root, then the verification passes. -// If the trie is constructed incorrectly then the root will have an incorrect key(len,path), and value, -// and therefore its hash won't match the expected root. -// ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484 -// -//nolint:gocyclo -func VerifyRangeProof(root, firstKey *felt.Felt, keys, values []*felt.Felt, proofSet *ProofSet, hash HashFunc) (bool, error) { - // Ensure the number of keys and values are the same - if len(keys) != len(values) { - return false, fmt.Errorf("inconsistent proof data, number of keys: %d, number of values: %d", len(keys), len(values)) - } - - // Ensure all keys are monotonic increasing - for i := 0; i < len(keys)-1; i++ { - if keys[i].Cmp(keys[i+1]) >= 0 { - return false, errors.New("keys are not monotonic increasing") - } - } - - // Ensure the range contains no deletions - for _, value := range values { - if value.Equal(&felt.Zero) { - return false, errors.New("range contains deletion") - } - } - - // Special case: no edge proof at all, given range is the whole leaf set in the trie - if proofSet == nil { - tr, err := NewTriePedersen(newMemStorage(), 251) //nolint:mnd - if err != nil { - return false, err - } - - for index, key := range keys { - _, err = tr.Put(key, values[index]) - if err != nil { - return false, err - } - } - - recomputedRoot, err := tr.Root() - if err != nil { - return false, err - } - - if !recomputedRoot.Equal(root) { - return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) - } - - return true, nil - } - - proofList := proofSet.List() - lastKey := keys[len(keys)-1] - - // Construct the left proof path - leftProofPath, err := ProofToPath(proofList, &Key{len: 251, bitset: firstKey.Bytes()}, hash) - if err != nil { - return false, err - } - - // Construct the right proof path - rightProofPath, err := ProofToPath(proofList, &Key{len: 251, bitset: lastKey.Bytes()}, hash) - if err != nil { - return false, err - } - - // Build the trie from the proof paths and the key-value pairs - tr, err := BuildTrie(leftProofPath, rightProofPath, keys, values) - if err != nil { - return false, err - } - - // Verify that the recomputed root hash matches the provided root hash - recomputedRoot, err := tr.Root() - if err != nil { - return false, err - } - - if !recomputedRoot.Equal(root) { - return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) - } - - return true, nil -} - // compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key func compressNode(idx int, proofNodes []ProofNode, hashF HashFunc) (int, uint8, error) { parent := proofNodes[idx] diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 56f0f402df..5bb59c9b25 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -428,56 +428,6 @@ func TestProve(t *testing.T) { }) } -func TestProveNKeys(t *testing.T) { - t.Parallel() - - n := 10000 - tempTrie := buildTrieWithNKeys(t, n) - - for i := 1; i < n+1; i++ { - keyFelt := new(felt.Felt).SetUint64(uint64(i)) - key := tempTrie.FeltToKey(keyFelt) - - proofSet := trie.NewProofSet() - err := tempTrie.Prove(keyFelt, proofSet) - require.NoError(t, err) - - root, err := tempTrie.Root() - require.NoError(t, err) - - val, err := trie.VerifyProof(root, &key, proofSet, crypto.Pedersen) - if err != nil { - t.Fatalf("failed for key %s", key.String()) - } - require.Equal(t, val, keyFelt) - } -} - -func TestProveNKeysWithNonExistentKeys(t *testing.T) { - t.Parallel() - - n := 10000 - tempTrie := buildTrieWithNKeys(t, n) - - for i := 1; i < n+1; i++ { - keyFelt := new(felt.Felt).SetUint64(uint64(i + n)) - key := tempTrie.FeltToKey(keyFelt) - - proofSet := trie.NewProofSet() - err := tempTrie.Prove(keyFelt, proofSet) - require.NoError(t, err) - - root, err := tempTrie.Root() - require.NoError(t, err) - - val, err := trie.VerifyProof(root, &key, proofSet, crypto.Pedersen) - if err != nil { - t.Fatalf("failed for key %s", key.String()) - } - require.Equal(t, &felt.Zero, val) - } -} - func TestProveRandomTrie(t *testing.T) { n := 1000 tempTrie, keys := buildRandomTrie(t, n) @@ -742,147 +692,6 @@ func TestProofToPath(t *testing.T) { }) } -// func TestVerifyRangeProof(t *testing.T) { -// t.Run("VPR two proofs, single key trie", func(t *testing.T) { -// // Node (edge path 249) -// // / \ -// // Node (binary) 0x6 (leaf) -// // / \ -// // 0x4 0x5 (leaf, leaf) - -// zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() -// zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) -// twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() -// twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - -// tri := build3KeyTrie(t) -// keys := []*felt.Felt{new(felt.Felt).SetUint64(1)} -// values := []*felt.Felt{new(felt.Felt).SetUint64(5)} -// proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} -// proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(6)} -// rootCommitment, err := tri.Root() -// require.NoError(t, err) -// proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) -// require.NoError(t, err) -// verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) -// require.NoError(t, err) -// require.True(t, verif) -// }) - -// t.Run("VPR all keys provided, no proofs needed", func(t *testing.T) { -// // Node (edge path 249) -// // / \ -// // Node (binary) 0x6 (leaf) -// // / \ -// // 0x4 0x5 (leaf, leaf) -// tri := build3KeyTrie(t) -// keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} -// values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} -// proofKeys := [2]*trie.Key{} -// proofValues := [2]*felt.Felt{} -// proofs := [2][]trie.ProofNode{} -// rootCommitment, err := tri.Root() -// require.NoError(t, err) -// verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) -// require.NoError(t, err) -// require.True(t, verif) -// }) - -// t.Run("VPR left proof, all right keys", func(t *testing.T) { -// // Node (edge path 249) -// // / \ -// // Node (binary) 0x6 (leaf) -// // / \ -// // 0x4 0x5 (leaf, leaf) - -// zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() -// zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - -// tri := build3KeyTrie(t) -// keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} -// values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} -// proofKeys := [2]*trie.Key{&zeroLeafkey} -// proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4)} -// leftProof, err := trie.GetProof(proofKeys[0], tri) -// require.NoError(t, err) -// proofs := [2][]trie.ProofNode{leftProof} -// rootCommitment, err := tri.Root() -// require.NoError(t, err) -// verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) -// require.NoError(t, err) -// require.True(t, verif) -// }) - -// t.Run("VPR right proof, all left keys", func(t *testing.T) { -// // Node (edge path 249) -// // / \ -// // Node (binary) 0x6 (leaf) -// // / \ -// // 0x4 0x5 (leaf, leaf) -// twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() -// twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - -// tri := build3KeyTrie(t) -// keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1)} -// values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5)} -// proofKeys := [2]*trie.Key{nil, &twoLeafkey} -// proofValues := [2]*felt.Felt{nil, new(felt.Felt).SetUint64(6)} -// rightProof, err := trie.GetProof(proofKeys[1], tri) -// require.NoError(t, err) -// proofs := [2][]trie.ProofNode{nil, rightProof} -// rootCommitment, err := tri.Root() -// require.NoError(t, err) -// verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) -// require.NoError(t, err) -// require.True(t, verif) -// }) - -// t.Run("VPR left proof, all inner keys, right proof with non-set key", func(t *testing.T) { -// zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() -// zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - -// threeFeltBytes := new(felt.Felt).SetUint64(3).Bytes() -// threeLeafkey := trie.NewKey(251, threeFeltBytes[:]) - -// tri := build4KeyTrie(t) -// keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} -// values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} -// proofKeys := [2]*trie.Key{&zeroLeafkey, &threeLeafkey} -// proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), nil} -// leftProof, err := trie.GetProof(proofKeys[0], tri) -// require.NoError(t, err) -// rightProof, err := trie.GetProof(proofKeys[1], tri) -// require.NoError(t, err) - -// proofs := [2][]trie.ProofNode{leftProof, rightProof} -// rootCommitment, err := tri.Root() -// require.NoError(t, err) - -// verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) -// require.NoError(t, err) -// require.True(t, verif) -// }) -// } - -func buildTrieWithNKeys(t *testing.T, numKeys int) *trie.Trie { - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) - - for i := 1; i < numKeys+1; i++ { - key := new(felt.Felt).SetUint64(uint64(i)) - _, err := tempTrie.Put(key, key) - require.NoError(t, err) - } - - require.NoError(t, tempTrie.Commit()) - - return tempTrie -} - func buildRandomTrie(t *testing.T, n int) (*trie.Trie, []*felt.Felt) { rrand := rand.New(rand.NewSource(3)) diff --git a/core/trie/proofset.go b/core/trie/proofset.go index 8ee8b5174c..cb20195b5b 100644 --- a/core/trie/proofset.go +++ b/core/trie/proofset.go @@ -47,14 +47,3 @@ func (ps *ProofSet) Size() int { return ps.size } - -// List returns a shallow copy of the proof set's node list. -func (ps *ProofSet) List() []ProofNode { - ps.lock.RLock() - defer ps.lock.RUnlock() - - nodes := make([]ProofNode, len(ps.nodeList)) - copy(nodes, ps.nodeList) - - return nodes -} diff --git a/rpc/storage_test.go b/rpc/storage_test.go index a778b8108c..0d32a2d863 100644 --- a/rpc/storage_test.go +++ b/rpc/storage_test.go @@ -302,7 +302,7 @@ func TestStorageProof(t *testing.T) { t.Parallel() contract := utils.HexToFelt(t, "0xabcd") - mockTrie.EXPECT().StorageTrieForAddr(gomock.Any()).Return(tempTrie, nil).Times(1) + mockTrie.EXPECT().StorageTrieForAddr(contract).Return(tempTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*noSuchKey}}} proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) @@ -317,8 +317,8 @@ func TestStorageProof(t *testing.T) { t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { t.Parallel() - contract := utils.HexToFelt(t, "0xabcd") - mockTrie.EXPECT().StorageTrieForAddr(gomock.Any()).Return(tempTrie, nil).Times(1) + contract := utils.HexToFelt(t, "0xadd0") + mockTrie.EXPECT().StorageTrieForAddr(contract).Return(tempTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys)