diff --git a/core/state.go b/core/state.go index 5e5e89f0a3..32ae9d908c 100644 --- a/core/state.go +++ b/core/state.go @@ -42,6 +42,10 @@ type StateReader interface { ContractNonce(addr *felt.Felt) (*felt.Felt, error) ContractStorage(addr, key *felt.Felt) (*felt.Felt, error) Class(classHash *felt.Felt) (*DeclaredClass, error) + ClassTrie() (*trie.Trie, func() error, error) + StorageTrie() (*trie.Trie, func() error, error) + StorageTrieForAddr(addr *felt.Felt) (*trie.Trie, error) + StateAndClassRoot() (*felt.Felt, *felt.Felt, error) } type State struct { @@ -733,3 +737,35 @@ func (s *State) buildReverseDiff(blockNumber uint64, diff *StateDiff) (*StateDif return &reversed, nil } + +func (s *State) StateAndClassRoot() (*felt.Felt, *felt.Felt, error) { + var storageRoot, classesRoot *felt.Felt + + sStorage, closer, err := s.storage() + if err != nil { + return nil, nil, err + } + + if storageRoot, err = sStorage.Root(); err != nil { + return nil, nil, err + } + + if err = closer(); err != nil { + return nil, nil, err + } + + classes, closer, err := s.classesTrie() + if err != nil { + return nil, nil, err + } + + if classesRoot, err = classes.Root(); err != nil { + return nil, nil, err + } + + if err = closer(); err != nil { + return nil, nil, err + } + + return storageRoot, classesRoot, nil +} diff --git a/core/trie/node.go b/core/trie/node.go index b62db62807..c56dde3603 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -18,7 +18,7 @@ type Node struct { } // Hash calculates the hash of a [Node] -func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { +func (n *Node) Hash(path *Key, hashFunc HashFunc) *felt.Felt { if path.Len() == 0 { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -33,7 +33,7 @@ func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { } // Hash calculates the hash of a [Node] -func (n *Node) HashFromParent(parnetKey, nodeKey *Key, hashFunc hashFunc) *felt.Felt { +func (n *Node) HashFromParent(parnetKey, nodeKey *Key, hashFunc HashFunc) *felt.Felt { path := path(nodeKey, parnetKey) return n.Hash(&path, hashFunc) } diff --git a/core/trie/proof.go b/core/trie/proof.go index 517ae60764..892deac667 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -13,7 +13,7 @@ var ( ) type ProofNode interface { - Hash(hash hashFunc) *felt.Felt + Hash(hash HashFunc) *felt.Felt Len() uint8 PrettyPrint() } @@ -23,7 +23,7 @@ type Binary struct { RightHash *felt.Felt } -func (b *Binary) Hash(hash hashFunc) *felt.Felt { +func (b *Binary) Hash(hash HashFunc) *felt.Felt { return hash(b.LeftHash, b.RightHash) } @@ -42,7 +42,7 @@ type Edge struct { Path *Key // path from parent to child } -func (e *Edge) Hash(hash hashFunc) *felt.Felt { +func (e *Edge) Hash(hash HashFunc) *felt.Felt { length := make([]byte, len(e.Path.bitset)) length[len(e.Path.bitset)-1] = e.Path.len pathFelt := e.Path.Felt() @@ -54,6 +54,11 @@ func (e *Edge) Len() uint8 { return e.Path.Len() } +func (e *Edge) PathInt() uint64 { + f := e.Path.Felt() + return f.Uint64() +} + func (e *Edge) PrettyPrint() { fmt.Printf(" Edge:\n") fmt.Printf(" Child: %v\n", e.Child) @@ -199,7 +204,7 @@ func traverseNodes(currNode ProofNode, path *[]ProofNode, nodeHashes map[felt.Fe // merges paths in the specified order [commonNodes..., leftNodes..., rightNodes...] // ordering of the merged path is not important // since SplitProofPath can discover the left and right paths using the merged path and the rootHash -func MergeProofPaths(leftPath, rightPath []ProofNode, hash hashFunc) ([]ProofNode, *felt.Felt, error) { +func MergeProofPaths(leftPath, rightPath []ProofNode, hash HashFunc) ([]ProofNode, *felt.Felt, error) { merged := []ProofNode{} minLen := min(len(leftPath), len(rightPath)) @@ -236,7 +241,7 @@ func MergeProofPaths(leftPath, rightPath []ProofNode, hash hashFunc) ([]ProofNod // SplitProofPath splits the merged proof path into two paths (left and right), which were merged before // it first validates that the merged path is not circular, the split happens at most once and rootHash exists // then calls traverseNodes to split the path to left and right paths -func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash hashFunc) ([]ProofNode, []ProofNode, error) { +func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash HashFunc) ([]ProofNode, []ProofNode, error) { commonPath := []ProofNode{} leftPath := []ProofNode{} rightPath := []ProofNode{} @@ -316,7 +321,7 @@ func GetProof(key *Key, tri *Trie) ([]ProofNode, error) { // verifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes` // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006 -func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool { +func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash HashFunc) bool { expectedHash := root remainingPath := NewKey(key.len, key.bitset[:]) for i, proofNode := range proofs { @@ -363,7 +368,7 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode // and therefore it's hash won't match the expected root. // ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484 func VerifyRangeProof(root *felt.Felt, keys, values []*felt.Felt, proofKeys [2]*Key, proofValues [2]*felt.Felt, - proofs [2][]ProofNode, hash hashFunc, + proofs [2][]ProofNode, hash HashFunc, ) (bool, error) { // Step 0: checks if len(keys) != len(values) { @@ -440,7 +445,7 @@ func ensureMonotonicIncreasing(proofKeys [2]*Key, keys []*felt.Felt) error { } // compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key -func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, error) { +func compressNode(idx int, proofNodes []ProofNode, hashF HashFunc) (int, uint8, error) { parent := proofNodes[idx] if idx == len(proofNodes)-1 { @@ -474,7 +479,7 @@ func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, } func assignChild(i, compressedParent int, parentNode *Node, - nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF hashFunc, + nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF HashFunc, ) (*Key, error) { childInd := i + compressedParent + 1 childKey, err := getChildKey(childInd, parentKey, leafKey, nilKey, proofNodes, hashF) @@ -494,7 +499,7 @@ func assignChild(i, compressedParent int, parentNode *Node, // ProofToPath returns a set of storage nodes from the root to the end of the proof path. // The storage nodes will have the hashes of the children, but only the key of the child // along the path outlined by the proof. -func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]StorageNode, error) { +func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF HashFunc) ([]StorageNode, error) { pathNodes := []StorageNode{} // Child keys that can't be derived are set to nilKey, so that we can store the node @@ -552,7 +557,7 @@ func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]Storag return pathNodes, nil } -func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF hashFunc) bool { +func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF HashFunc) bool { lastNode := pathNodes[len(pathNodes)-1].node noLeftMatch, noRightMatch := false, false if lastNode.LeftHash != nil && !pNode.Hash(hashF).Equal(lastNode.LeftHash) { @@ -607,7 +612,7 @@ func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key, return crntKey, err } -func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF hashFunc) (*Key, error) { +func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF HashFunc) (*Key, error) { if childIdx > len(proofNodes)-1 { return nilKey, nil } diff --git a/core/trie/trie.go b/core/trie/trie.go index c03357d3af..28c75fd891 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -13,7 +13,7 @@ import ( "github.com/NethermindEth/juno/db" ) -type hashFunc func(*felt.Felt, *felt.Felt) *felt.Felt +type HashFunc func(*felt.Felt, *felt.Felt) *felt.Felt // Trie is a dense Merkle Patricia Trie (i.e., all internal nodes have two children). // @@ -37,7 +37,7 @@ type Trie struct { rootKey *Key maxKey *felt.Felt storage *Storage - hash hashFunc + hash HashFunc dirtyNodes []*Key rootKeyIsDirty bool @@ -53,7 +53,7 @@ func NewTriePoseidon(storage *Storage, height uint8) (*Trie, error) { return newTrie(storage, height, crypto.Poseidon) } -func newTrie(storage *Storage, height uint8, hash hashFunc) (*Trie, error) { +func newTrie(storage *Storage, height uint8, hash HashFunc) (*Trie, error) { if height > felt.Bits { return nil, fmt.Errorf("max trie height is %d, got: %d", felt.Bits, height) } @@ -668,6 +668,10 @@ func (t *Trie) RootKey() *Key { return t.rootKey } +func (t *Trie) HashFunc() HashFunc { + return t.hash +} + func (t *Trie) Dump() { t.dump(0, nil) } diff --git a/rpc/storage.go b/rpc/storage.go index facaf47539..8f19ed7948 100644 --- a/rpc/storage.go +++ b/rpc/storage.go @@ -34,31 +34,196 @@ func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *js // // It follows the specification defined here: // https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L910 -func (h *Handler) StorageProof(classes, contracts []felt.Felt, storageKeys []StorageKeys) (*felt.Felt, *jsonrpc.Error) { +func (h *Handler) StorageProof(id BlockID, classes, contracts []felt.Felt, storageKeys []StorageKeys) (*StorageProofResult, *jsonrpc.Error) { stateReader, stateCloser, err := h.bcReader.HeadState() if err != nil { return nil, ErrInternal.CloneWithData(err) } defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageProof") - // TODO: Extend state reader interface? - s := stateReader.(*core.State) - clt, _, err := s.ClassTrie() + head, err := h.bcReader.Head() if err != nil { return nil, ErrInternal.CloneWithData(err) } - for _, elt := range classes { - feltBytes := elt.Bytes() - key := trie.NewKey(core.ContractStorageTrieHeight, feltBytes[:]) - nodes, err := trie.GetProof(&key, clt) - // adapt proofs to the expected format + + storageRoot, classRoot, err := stateReader.StateAndClassRoot() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + result := &StorageProofResult{ + GlobalRoots: &GlobalRoots{ + ContractsTreeRoot: storageRoot, + ClassesTreeRoot: classRoot, + BlockHash: head.Hash, + }, + } + + result.ClassesProof, err = getClassesProof(stateReader, classes) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + result.ContractsProof, err = getContractsProof(stateReader, contracts) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + result.ContractsStorageProofs, err = getContractsStorageProofs(stateReader, storageKeys) + if err != nil { + return nil, ErrInternal.CloneWithData(err) } - return nil, ErrUnexpectedError + return result, nil } +// StorageKeys represents an item in `contracts_storage_keys. parameter // https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L938 type StorageKeys struct { Contract felt.Felt `json:"contract_address"` Keys []felt.Felt `json:"storage_keys"` } + +// MerkleNode represents a proof node in a trie +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3632 +type MerkleNode interface{} + +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3644 +type MerkleBinaryNode struct { + Left *felt.Felt `json:"left"` + Right *felt.Felt `json:"right"` +} + +// TODO[pnowosie]: link to specs +type MerkleEdgeNode struct { + Path *felt.Felt `json:"path"` + Length int `json:"length"` + Child *felt.Felt `json:"child"` +} + +// TODO[pnowosie]: link to specs +type MerkleLeafNode struct { + Value *felt.Felt `json:"value"` +} + +// HashToNode represents an item in `NODE_HASH_TO_NODE_MAPPING` specified here +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3667 +type HashToNode struct { + Hash *felt.Felt `json:"node_hash"` + Node MerkleNode `json:"node"` +} + +// TODO[pnowosie]: link to specs +type LeafData struct { + Nonce *felt.Felt `json:"nonce"` + ClassHash *felt.Felt `json:"class_hash"` +} + +// TODO[pnowosie]: link to specs +type ContractProof struct { + Nodes []*HashToNode `json:"nodes"` + LeavesData []*LeafData `json:"contract_leaves_data"` +} + +// TODO[pnowosie]: link to specs +type GlobalRoots struct { + ContractsTreeRoot *felt.Felt `json:"contracts_tree_root"` + ClassesTreeRoot *felt.Felt `json:"classes_tree_root"` + BlockHash *felt.Felt `json:"block_hash"` +} + +// TODO[pnowosie]: link to specs +type StorageProofResult struct { + ClassesProof []*HashToNode `json:"classes_proof"` + ContractsProof *ContractProof `json:"contracts_proof"` + ContractsStorageProofs [][]*HashToNode `json:"contracts_storage_proofs"` + GlobalRoots *GlobalRoots `json:"global_roots"` +} + +func getClassesProof(reader core.StateReader, classes []felt.Felt) ([]*HashToNode, error) { + ctrie, _, err := reader.ClassTrie() + if err != nil { + return nil, err + } + result := []*HashToNode{} + for _, class := range classes { + nodes, err := getProof(ctrie, &class) + if err != nil { + return nil, err + } + result = append(result, nodes...) + } + return result, nil +} + +func getContractsProof(reader core.StateReader, contracts []felt.Felt) (*ContractProof, error) { + strie, _, err := reader.StorageTrie() + if err != nil { + return nil, err + } + + result := &ContractProof{ + Nodes: []*HashToNode{}, + LeavesData: make([]*LeafData, 0, len(contracts)), + } + + for _, contract := range contracts { + leafData := &LeafData{} + leafData.Nonce, err = reader.ContractNonce(&contract) + if err != nil { + return nil, err + } + leafData.ClassHash, err = reader.ContractClassHash(&contract) + if err != nil { + return nil, err + } + result.LeavesData = append(result.LeavesData, leafData) + + nodes, err := getProof(strie, &contract) + if err != nil { + return nil, err + } + result.Nodes = append(result.Nodes, nodes...) + } + + return result, nil +} + +func getContractsStorageProofs(reader core.StateReader, keys []StorageKeys) ([][]*HashToNode, error) { + return nil, nil +} + +func getProof(t *trie.Trie, elt *felt.Felt) ([]*HashToNode, error) { + feltBytes := elt.Bytes() + key := trie.NewKey(core.ContractStorageTrieHeight, feltBytes[:]) + nodes, err := trie.GetProof(&key, t) + if err != nil { + return nil, err + } + + // adapt proofs to the expected format + hashnodes := make([]*HashToNode, len(nodes)) + for i, node := range nodes { + var merkle MerkleNode + + if binary, ok := node.(*trie.Binary); ok { + merkle = &MerkleBinaryNode{ + Left: binary.LeftHash, + Right: binary.RightHash, + } + } + if edge, ok := node.(*trie.Edge); ok { + f := edge.Path.Felt() + merkle = &MerkleEdgeNode{ + Path: &f, // TODO[pnowosie]: specs says its int + Length: int(edge.Len()), + Child: edge.Child, + } + } + + hashnodes[i] = &HashToNode{ + Hash: node.Hash(t.HashFunc()), + Node: merkle, + } + } + + return hashnodes, nil +} diff --git a/rpc/storage_test.go b/rpc/storage_test.go index 33cf25f0cc..da657867f4 100644 --- a/rpc/storage_test.go +++ b/rpc/storage_test.go @@ -103,10 +103,12 @@ func TestStorageProof(t *testing.T) { log := utils.NewNopZapLogger() handler := rpc.New(mockReader, nil, nil, "", log) + blockLatest := rpc.BlockID{Latest: true} + t.Run("empty blockchain", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) @@ -114,7 +116,7 @@ func TestStorageProof(t *testing.T) { t.Run("class trie hash does not exist in a trie", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) @@ -122,7 +124,7 @@ func TestStorageProof(t *testing.T) { t.Run("class trie hash exists in a trie", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) @@ -130,7 +132,7 @@ func TestStorageProof(t *testing.T) { t.Run("storage trie address does not exist in a trie", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) @@ -138,7 +140,7 @@ func TestStorageProof(t *testing.T) { t.Run("storage trie address exists in a trie", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) @@ -146,7 +148,7 @@ func TestStorageProof(t *testing.T) { t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) @@ -154,7 +156,7 @@ func TestStorageProof(t *testing.T) { t.Run("contract storage trie key slot does not exist in a trie", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) @@ -162,7 +164,7 @@ func TestStorageProof(t *testing.T) { t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr) @@ -170,7 +172,7 @@ func TestStorageProof(t *testing.T) { t.Run("class & storage tries proofs requested", func(t *testing.T) { //mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - proof, rpcErr := handler.StorageProof(nil, nil, nil) + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) require.Nil(t, proof) assert.Equal(t, rpc.ErrUnexpectedError, rpcErr)