Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Fix proof direction and key bit order #22

Merged
merged 9 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mapstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type InvalidKeyError struct {
}

func (e *InvalidKeyError) Error() string {
return fmt.Sprintf("invalid key: %s", e.Key)
return fmt.Sprintf("invalid key: %x", e.Key)
}

// SimpleMap is a simple in-memory map.
Expand Down
8 changes: 4 additions & 4 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va
}

// Recompute root.
for i := len(proof.SideNodes) - 1; i >= 0; i-- {
for i := 0; i < len(proof.SideNodes); i++ {
node := make([]byte, th.pathSize())
copy(node, proof.SideNodes[i])

if hasBit(path, i) == right {
if getBitAtFromMSB(path, len(proof.SideNodes)-1-i) == right {
currentHash, currentData = th.digestNode(node, currentHash)
} else {
currentHash, currentData = th.digestNode(currentHash, node)
Expand Down Expand Up @@ -170,7 +170,7 @@ func CompactProof(proof SparseMerkleProof, hasher hash.Hash) (SparseCompactMerkl
node := make([]byte, th.hasher.Size())
copy(node, proof.SideNodes[i])
if bytes.Equal(node, th.placeholder()) {
setBit(bitMask, i)
setBitAtFromMSB(bitMask, i)
} else {
compactedSideNodes = append(compactedSideNodes, node)
}
Expand All @@ -195,7 +195,7 @@ func DecompactProof(proof SparseCompactMerkleProof, hasher hash.Hash) (SparseMer
decompactedSideNodes := make([][]byte, proof.NumSideNodes)
position := 0
for i := 0; i < proof.NumSideNodes; i++ {
if hasBit(proof.BitMask, i) == 1 {
if getBitAtFromMSB(proof.BitMask, i) == 1 {
decompactedSideNodes[i] = th.placeholder()
} else {
decompactedSideNodes[i] = proof.SideNodes[position]
Expand Down
4 changes: 2 additions & 2 deletions proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func TestCompactProofsSanityCheck(t *testing.T) {

// Case (compact proofs): unexpected bit mask length.
proof, _ = smt.ProveCompact([]byte("testKey1"))
proof.NumSideNodes = 1
proof.NumSideNodes = 10
if proof.sanityCheck(th) {
t.Error("sanity check incorrectly passed")
}
Expand All @@ -221,7 +221,7 @@ func TestCompactProofsSanityCheck(t *testing.T) {

// Case (compact proofs): unexpected number of sidenodes for number of side nodes.
proof, _ = smt.ProveCompact([]byte("testKey1"))
proof.SideNodes = proof.SideNodes[:1]
proof.SideNodes = append(proof.SideNodes, proof.SideNodes...)
if proof.sanityCheck(th) {
t.Error("sanity check incorrectly passed")
}
Expand Down
36 changes: 22 additions & 14 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (smt *SparseMerkleTree) GetForRoot(key []byte, root []byte) ([]byte, error)
}

leftNode, rightNode := smt.th.parseNode(currentData)
if hasBit(path, i) == right {
if getBitAtFromMSB(path, i) == right {
currentHash = rightNode
} else {
currentHash = leftNode
Expand Down Expand Up @@ -188,7 +188,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte

var currentHash, currentData []byte
nonPlaceholderReached := false
for i := smt.depth() - 1; i >= 0; i-- {
for i := 0; i < len(sideNodes); i++ {
if sideNodes[i] == nil {
continue
}
Expand All @@ -215,7 +215,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte
}

if !nonPlaceholderReached && bytes.Equal(sideNode, smt.th.placeholder()) {
// We found another placeholder sibling node, keep going down the
// We found another placeholder sibling node, keep going up the
// tree until we find the first sibling that is not a placeholder.
continue
} else if !nonPlaceholderReached {
Expand All @@ -224,7 +224,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte
nonPlaceholderReached = true
}

if hasBit(path, i) == right {
if getBitAtFromMSB(path, len(sideNodes)-1-i) == right {
currentHash, currentData = smt.th.digestNode(sideNode, currentData)
} else {
currentHash, currentData = smt.th.digestNode(currentData, sideNode)
Expand Down Expand Up @@ -269,7 +269,7 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
commonPrefixCount = countCommonPrefix(path, actualPath)
}
if commonPrefixCount != smt.depth() {
if hasBit(path, commonPrefixCount) == right {
if getBitAtFromMSB(path, commonPrefixCount) == right {
currentHash, currentData = smt.th.digestNode(oldLeafHash, currentData)
} else {
currentHash, currentData = smt.th.digestNode(currentData, oldLeafHash)
Expand All @@ -283,11 +283,15 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
currentData = currentHash
}

for i := smt.depth() - 1; i >= 0; i-- {
for i := 0; i < smt.depth(); i++ {
sideNode := make([]byte, smt.th.pathSize())

if sideNodes[i] == nil {
if commonPrefixCount != smt.depth() && commonPrefixCount > i {
// The offset from the bottom of the tree to the start of the side nodes
// i-offsetOfSideNodes is the index into sideNodes[]
offsetOfSideNodes := smt.depth() - len(sideNodes)

if i-offsetOfSideNodes < 0 || sideNodes[i-offsetOfSideNodes] == nil {
if commonPrefixCount != smt.depth() && commonPrefixCount > smt.depth()-1-i {
// If there are no sidenodes at this height, but the number of
// bits that the paths of the two leaf nodes share in common is
// greater than this height, then we need to build up the tree
Expand All @@ -297,10 +301,10 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
continue
}
} else {
copy(sideNode, sideNodes[i])
copy(sideNode, sideNodes[i-offsetOfSideNodes])
}

if hasBit(path, i) == right {
if getBitAtFromMSB(path, smt.depth()-1-i) == right {
currentHash, currentData = smt.th.digestNode(sideNode, currentData)
} else {
currentHash, currentData = smt.th.digestNode(currentData, sideNode)
Expand All @@ -319,7 +323,7 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
// Returns an array of sibling nodes, the leaf hash found at that path and the
// leaf data. If the leaf is a placeholder, the leaf data is nil.
func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byte, []byte, []byte, error) {
sideNodes := make([][]byte, smt.depth())
var sideNodes [][]byte
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved

if bytes.Equal(root, smt.th.placeholder()) {
// If the root is a placeholder, there are no sidenodes to return.
Expand All @@ -340,11 +344,15 @@ func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byt
leftNode, rightNode := smt.th.parseNode(currentData)

// Get sidenode depending on whether the path bit is on or off.
if hasBit(path, i) == right {
sideNodes[i] = leftNode
if getBitAtFromMSB(path, i) == right {
sideNodes = append(sideNodes, nil)
copy(sideNodes[1:], sideNodes)
sideNodes[0] = leftNode
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
nodeHash = rightNode
} else {
sideNodes[i] = rightNode
sideNodes = append(sideNodes, nil)
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
copy(sideNodes[1:], sideNodes)
sideNodes[0] = rightNode
nodeHash = leftNode
}

Expand Down
112 changes: 111 additions & 1 deletion smt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,116 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
}
}

// Test known tree ops
func TestSparseMerkleTreeKnown(t *testing.T) {
h := newDummyHasher(sha256.New())
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, h)
var value []byte
var err error

baseKey := make([]byte, h.Size()+4)
key1 := make([]byte, h.Size()+4)
copy(key1, baseKey)
key1[4] = byte(0b00000000)
key2 := make([]byte, h.Size()+4)
copy(key2, baseKey)
key2[4] = byte(0b01000000)
key3 := make([]byte, h.Size()+4)
copy(key3, baseKey)
key3[4] = byte(0b10000000)
key4 := make([]byte, h.Size()+4)
copy(key4, baseKey)
key4[4] = byte(0b11000000)
key5 := make([]byte, h.Size()+4)
copy(key5, baseKey)
key5[4] = byte(0b11010000)

_, err = smt.Update(key1, []byte("testValue1"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key2, []byte("testValue2"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key3, []byte("testValue3"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key4, []byte("testValue4"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key5, []byte("testValue5"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}

value, err = smt.Get(key1)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue1"), value) {
t.Error("did not get correct value when getting non-empty key")
}
value, err = smt.Get(key2)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue2"), value) {
t.Error("did not get correct value when getting non-empty key")
}
value, err = smt.Get(key3)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue3"), value) {
t.Error("did not get correct value when getting non-empty key")
}
value, err = smt.Get(key4)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue4"), value) {
t.Error("did not get correct value when getting non-empty key")
}
value, err = smt.Get(key5)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue5"), value) {
t.Error("did not get correct value when getting non-empty key")
}

proof1, _ := smt.Prove(key1)
proof2, _ := smt.Prove(key2)
proof3, _ := smt.Prove(key3)
proof4, _ := smt.Prove(key4)
proof5, _ := smt.Prove(key5)
dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), h, smt.Root())
err = dsmst.AddBranch(proof1, key1, []byte("testValue1"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof2, key2, []byte("testValue2"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof3, key3, []byte("testValue3"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof4, key4, []byte("testValue4"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof5, key5, []byte("testValue5"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
}

// Test tree operations when two leafs are immediate neighbours.
func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
h := newDummyHasher(sha256.New())
Expand All @@ -187,7 +297,7 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
key1[h.Size()+4-1] = byte(0)
key2 := make([]byte, h.Size()+4)
copy(key2, key1)
setBit(key2, (h.Size()+4)*8-1)
setBitAtFromMSB(key2, (h.Size()+4)*8-1)

_, err = smt.Update(key1, []byte("testValue1"))
if err != nil {
Expand Down
14 changes: 8 additions & 6 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
package smt

func hasBit(data []byte, position int) int {
if int(data[position/8])&(1<<(uint(position)%8)) > 0 {
// getBitAtFromMSB gets the bit at an offset from the most significant bit
func getBitAtFromMSB(data []byte, position int) int {
if int(data[position/8])&(1<<(8-1-uint(position)%8)) > 0 {
return 1
}
return 0
}

func setBit(data []byte, position int) {
// setBitAtFromMSB sets the bit at an offset from the most significant bit
func setBitAtFromMSB(data []byte, position int) {
n := int(data[position/8])
n |= (1 << (uint(position) % 8))
n |= (1 << (8 - 1 - uint(position)%8))
data[position/8] = byte(n)
}

func countSetBits(data []byte) int {
count := 0
for i := 0; i < len(data)*8; i++ {
if hasBit(data, i) == 1 {
if getBitAtFromMSB(data, i) == 1 {
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
count++
}
}
Expand All @@ -26,7 +28,7 @@ func countSetBits(data []byte) int {
func countCommonPrefix(data1 []byte, data2 []byte) int {
count := 0
for i := 0; i < len(data1)*8; i++ {
if hasBit(data1, i) == hasBit(data2, i) {
if getBitAtFromMSB(data1, i) == getBitAtFromMSB(data2, i) {
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
count++
} else {
break
Expand Down