Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

adds support for raw key instead of hashed key #64

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.idea
go.sum
rach-id marked this conversation as resolved.
Show resolved Hide resolved
rach-id marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions CHANGELOG-PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Month, DD, YYYY
### BREAKING CHANGES

- [go package] (Link to PR) Description @username
- [smt](https://github.com/celestiaorg/smt/pull/64) Adds support for raw key instead of hashed key [@SweeXordious](https://github.com/SweeXordious)

### FEATURES

Expand Down
25 changes: 17 additions & 8 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,38 @@ package smt

import (
"crypto/sha256"
"fmt"
"strconv"
"testing"
)

func BenchmarkSparseMerkleTree_Update(b *testing.B) {
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())
hasher := sha256.New()
smn, smv := NewSimpleMap(hasher.Size()), NewSimpleMap(9)
smt := NewSparseMerkleTree(smn, smv, hasher)

b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
s := strconv.Itoa(i)
_, _ = smt.Update([]byte(s), []byte(s))
s := fmt.Sprintf("%09d", i)
_, err := smt.Update([]byte(s), []byte(s))
if err != nil {
b.Error(err)
}
}
}

func BenchmarkSparseMerkleTree_Delete(b *testing.B) {
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())
hasher := sha256.New()
smn, smv := NewSimpleMap(hasher.Size()), NewSimpleMap(9)
smt := NewSparseMerkleTree(smn, smv, hasher)

for i := 0; i < 100000; i++ {
s := strconv.Itoa(i)
_, _ = smt.Update([]byte(s), []byte(s))
s := fmt.Sprintf("%09d", i)
_, err := smt.Update([]byte(s), []byte(s))
if err != nil {
b.Error(err)
}
}

b.ResetTimer()
Expand Down
15 changes: 8 additions & 7 deletions bulk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@ func TestSparseMerkleTree(t *testing.T) {

// Test all tree operations in bulk, with specified ratio probabilities of insert, update and delete.
func bulkOperations(t *testing.T, operations int, insert int, update int, delete int) {
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())
hasher := sha256.New()
keyLen := 16 + rand.Intn(32)
smn, smv := NewSimpleMap(hasher.Size()), NewSimpleMap(keyLen)
smt := NewSparseMerkleTree(smn, smv, hasher)

max := insert + update + delete
kv := make(map[string]string)

for i := 0; i < operations; i++ {
n := rand.Intn(max)
if n < insert { // Insert
keyLen := 16 + rand.Intn(32)
key := make([]byte, keyLen)
rand.Read(key)

Expand Down Expand Up @@ -93,14 +94,14 @@ func bulkCheckAll(t *testing.T, smt *SparseMerkleTree, kv *map[string]string) {
if err != nil {
t.Errorf("error: %v", err)
}
if !VerifyProof(proof, smt.Root(), []byte(k), []byte(v), smt.th.hasher) {
if !VerifyProof(proof, smt.Root(), []byte(k), []byte(v), smt.th.hasher, smt.values.GetKeySize()) {
t.Error("Merkle proof failed to verify")
}
compactProof, err := smt.ProveCompact([]byte(k))
if err != nil {
t.Errorf("error: %v", err)
}
if !VerifyCompactProof(compactProof, smt.Root(), []byte(k), []byte(v), smt.th.hasher) {
if !VerifyCompactProof(compactProof, smt.Root(), []byte(k), []byte(v), smt.th.hasher, smt.values.GetKeySize()) {
t.Error("Merkle proof failed to verify")
}

Expand All @@ -114,12 +115,12 @@ func bulkCheckAll(t *testing.T, smt *SparseMerkleTree, kv *map[string]string) {
if v2 == "" {
continue
}
commonPrefix := countCommonPrefix(smt.th.path([]byte(k)), smt.th.path([]byte(k2)))
commonPrefix := countCommonPrefix([]byte(k), []byte(k2))
if commonPrefix != smt.depth() && commonPrefix > largestCommonPrefix {
largestCommonPrefix = commonPrefix
}
}
sideNodes, _, _, _, err := smt.sideNodesForRoot(smt.th.path([]byte(k)), smt.Root(), false)
sideNodes, _, _, _, err := smt.sideNodesForRoot([]byte(k), smt.Root(), false)
if err != nil {
t.Errorf("error: %v", err)
}
Expand Down
23 changes: 14 additions & 9 deletions deepsubtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@ func NewDeepSparseMerkleSubTree(nodes, values MapStore, hasher hash.Hash, root [
//
// If the leaf may be updated (e.g. during a state transition fraud proof),
// an updatable proof should be used. See SparseMerkleTree.ProveUpdatable.
func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []byte, value []byte) error {
result, updates := verifyProofWithUpdates(proof, dsmst.Root(), key, value, dsmst.th.hasher)
func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []byte, value []byte, keySize int) error {
if len(key) != keySize {
return ErrWrongKeySize
}
result, updates := verifyProofWithUpdates(proof, dsmst.Root(), key, value, dsmst.th.hasher, keySize)
if !result {
return ErrBadProof
}

if !bytes.Equal(value, defaultValue) { // Membership proof.
if err := dsmst.values.Set(dsmst.th.path(key), value); err != nil {
if err := dsmst.values.Set(key, value); err != nil {
return err
}
}
Expand Down Expand Up @@ -64,6 +67,9 @@ func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []b
// Use if a key was _not_ previously added with AddBranch, otherwise use Get.
// Errors if the key cannot be reached by descending.
func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) {
if len(key) != smt.values.GetKeySize() {
return nil, ErrWrongKeySize
}
// Get tree's root
root := smt.Root()

Expand All @@ -72,29 +78,28 @@ func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) {
return defaultValue, nil
}

path := smt.th.path(key)
currentHash := root
for i := 0; i < smt.depth(); i++ {
currentData, err := smt.nodes.Get(currentHash)
if err != nil {
return nil, err
} else if smt.th.isLeaf(currentData) {
// We've reached the end. Is this the actual leaf?
p, _ := smt.th.parseLeaf(currentData)
if !bytes.Equal(path, p) {
p, _ := smt.th.parseLeaf(currentData, smt.values.GetKeySize())
if !bytes.Equal(key, p) {
// Nope. Therefore the key is actually empty.
return defaultValue, nil
}
// Otherwise, yes. Return the value.
value, err := smt.values.Get(path)
value, err := smt.values.Get(key)
if err != nil {
return nil, err
}
return value, nil
}

leftNode, rightNode := smt.th.parseNode(currentData)
if getBitAtFromMSB(path, i) == right {
if getBitAtFromMSB(key, i) == right {
currentHash = rightNode
} else {
currentHash = leftNode
Expand All @@ -109,7 +114,7 @@ func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) {
// The following lines of code should only be reached if the path is 256
// nodes high, which should be very unlikely if the underlying hash function
// is collision-resistant.
value, err := smt.values.Get(path)
value, err := smt.values.Get(key)
if err != nil {
return nil, err
}
Expand Down
75 changes: 58 additions & 17 deletions deepsubtree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,54 @@ import (
"testing"
)

func TestDeepSubTreeKeySizeChecks(t *testing.T) {
hasher := sha256.New()
keySize := len([]byte("testKey1"))
smn, smv := NewSimpleMap(hasher.Size()), NewSimpleMap(keySize)
smt := NewSparseMerkleTree(smn, smv, hasher)

_, err := smt.Update([]byte("testKey1"), []byte("testValue1"))
if err != nil {
t.Errorf("couldn't update smt. exception: %v", err)
}

proof, err := smt.Prove([]byte("testKey1"))
if err != nil {
t.Errorf("couldn't prove existing key. Actual exception: %v", err)
}

dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(hasher.Size()), NewSimpleMap(keySize), sha256.New(), smt.Root())

err = dsmst.AddBranch(proof, randomBytes(keySize+1), []byte("testValue1"), smt.values.GetKeySize())
if err != ErrWrongKeySize {
t.Errorf("should have complained of `keySize + 1` when adding branch. Actual exception: %v", err)
}

err = dsmst.AddBranch(proof, randomBytes(keySize-1), []byte("testValue1"), smt.values.GetKeySize())
if err != ErrWrongKeySize {
t.Errorf("should have complained of `keySize - 1` when adding branch. Actual exception: %v", err)
}

_, err = dsmst.GetDescend(randomBytes(keySize + 1))
if err != ErrWrongKeySize {
t.Errorf("should have complained of `keySize + 1` when getting descend. Actual exception: %v", err)
}

_, err = dsmst.GetDescend(randomBytes(keySize - 1))
if err != ErrWrongKeySize {
t.Errorf("should have complained of `keySize - 1` when getting descend. Actual exception: %v", err)
}
}

func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {
smt := NewSparseMerkleTree(NewSimpleMap(), NewSimpleMap(), sha256.New())
hasher := sha256.New()
smt := NewSparseMerkleTree(NewSimpleMap(hasher.Size()), NewSimpleMap(len([]byte("testKey1"))), hasher)

smt.Update([]byte("testKey1"), []byte("testValue1"))
smt.Update([]byte("testKey2"), []byte("testValue2"))
smt.Update([]byte("testKey3"), []byte("testValue3"))
smt.Update([]byte("testKey4"), []byte("testValue4"))
smt.Update([]byte("testKey6"), []byte("testValue6"))
_, _ = smt.Update([]byte("testKey1"), []byte("testValue1"))
_, _ = smt.Update([]byte("testKey2"), []byte("testValue2"))
_, _ = smt.Update([]byte("testKey3"), []byte("testValue3"))
_, _ = smt.Update([]byte("testKey4"), []byte("testValue4"))
_, _ = smt.Update([]byte("testKey6"), []byte("testValue6"))

originalRoot := make([]byte, len(smt.Root()))
copy(originalRoot, smt.Root())
Expand All @@ -23,16 +63,16 @@ func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {
proof2, _ := smt.ProveUpdatable([]byte("testKey2"))
proof5, _ := smt.ProveUpdatable([]byte("testKey5"))

dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), sha256.New(), smt.Root())
err := dsmst.AddBranch(proof1, []byte("testKey1"), []byte("testValue1"))
dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(hasher.Size()), NewSimpleMap(len([]byte("testKey1"))), sha256.New(), smt.Root())
err := dsmst.AddBranch(proof1, []byte("testKey1"), []byte("testValue1"), smt.values.GetKeySize())
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof2, []byte("testKey2"), []byte("testValue2"))
err = dsmst.AddBranch(proof2, []byte("testKey2"), []byte("testValue2"), smt.values.GetKeySize())
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof5, []byte("testKey5"), defaultValue)
err = dsmst.AddBranch(proof5, []byte("testKey5"), defaultValue, smt.values.GetKeySize())
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
Expand Down Expand Up @@ -141,18 +181,19 @@ func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {
}

func TestDeepSparseMerkleSubTreeBadInput(t *testing.T) {
smt := NewSparseMerkleTree(NewSimpleMap(), NewSimpleMap(), sha256.New())
hasher := sha256.New()
smt := NewSparseMerkleTree(NewSimpleMap(hasher.Size()), NewSimpleMap(len([]byte("testKey1"))), hasher) // to be refactored

smt.Update([]byte("testKey1"), []byte("testValue1"))
smt.Update([]byte("testKey2"), []byte("testValue2"))
smt.Update([]byte("testKey3"), []byte("testValue3"))
smt.Update([]byte("testKey4"), []byte("testValue4"))
_, _ = smt.Update([]byte("testKey1"), []byte("testValue1"))
_, _ = smt.Update([]byte("testKey2"), []byte("testValue2"))
_, _ = smt.Update([]byte("testKey3"), []byte("testValue3"))
_, _ = smt.Update([]byte("testKey4"), []byte("testValue4"))

badProof, _ := smt.Prove([]byte("testKey1"))
badProof.SideNodes[0][0] = byte(0)

dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), sha256.New(), smt.Root())
err := dsmst.AddBranch(badProof, []byte("testKey1"), []byte("testValue1"))
dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(hasher.Size()), NewSimpleMap(len([]byte("testKey1"))), hasher, smt.Root()) // to be refactored
err := dsmst.AddBranch(badProof, []byte("testKey1"), []byte("testValue1"), smt.values.GetKeySize())
if !errors.Is(err, ErrBadProof) {
t.Error("did not return ErrBadProof for bad proof input")
}
Expand Down
8 changes: 5 additions & 3 deletions fuzz/delete/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ func Fuzz(data []byte) int {
return -1
}

smn, smv := smt.NewSimpleMap(), smt.NewSimpleMap()
tree := smt.NewSparseMerkleTree(smn, smv, sha256.New())
hasher := sha256.New()
keySize := 10
smn, smv := smt.NewSimpleMap(hasher.Size()), smt.NewSimpleMap(keySize)
tree := smt.NewSparseMerkleTree(smn, smv, hasher)
for i := 0; i < len(splits)-1; i += 2 {
key, value := splits[i], splits[i+1]
tree.Update(key, value)
_, _ = tree.Update(key, value)
}

deleteKey := splits[len(splits)-1]
Expand Down
19 changes: 11 additions & 8 deletions fuzz/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@ import (
"github.com/celestiaorg/smt"
)

// Fuzz FIXME
func Fuzz(input []byte) int {
if len(input) < 100 {
return 0
}
smn, smv := smt.NewSimpleMap(), smt.NewSimpleMap()
tree := smt.NewSparseMerkleTree(smn, smv, sha256.New())
hasher := sha256.New()
keySize := 10
smn, smv := smt.NewSimpleMap(hasher.Size()), smt.NewSimpleMap(keySize)
tree := smt.NewSparseMerkleTree(smn, smv, hasher)
r := bytes.NewReader(input)
var keys [][]byte
key := func() []byte {
if readByte(r) < math.MaxUint8/2 {
k := make([]byte, readByte(r)/2)
r.Read(k)
_, _ = r.Read(k)
keys = append(keys, k)
return k
}
Expand All @@ -37,17 +40,17 @@ func Fuzz(input []byte) int {
op := op(int(b) % int(Noop))
switch op {
case Get:
tree.Get(key())
_, _ = tree.Get(key())
case Update:
value := make([]byte, 32)
binary.BigEndian.PutUint64(value, uint64(i))
tree.Update(key(), value)
_, _ = tree.Update(key(), value)
case Delete:
tree.Delete(key())
_, _ = tree.Delete(key())
case Prove:
tree.Prove(key())
_, _ = tree.Prove(key())
case Has:
tree.Has(key())
_, _ = tree.Has(key())
}
}
return 1
Expand Down
13 changes: 13 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package smt

import (
"math/rand"
"time"
)

func randomBytes(length int) []byte {
rand.Seed(time.Now().UnixNano())
rach-id marked this conversation as resolved.
Show resolved Hide resolved
b := make([]byte, length)
rand.Read(b)
return b
}
Loading