Skip to content

Commit 65b7507

Browse files
authored
Fix and refactor trie proof logics (#2252)
1 parent 2b1b219 commit 65b7507

File tree

9 files changed

+1536
-1591
lines changed

9 files changed

+1536
-1591
lines changed

core/trie/key.go

+73-42
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ package trie
33
import (
44
"bytes"
55
"encoding/hex"
6-
"errors"
76
"fmt"
87
"math/big"
98

109
"github.com/NethermindEth/juno/core/felt"
1110
)
1211

12+
var NilKey = &Key{len: 0, bitset: [32]byte{}}
13+
1314
type Key struct {
1415
len uint8
1516
bitset [32]byte
@@ -24,26 +25,6 @@ func NewKey(length uint8, keyBytes []byte) Key {
2425
return k
2526
}
2627

27-
func (k *Key) SubKey(n uint8) (*Key, error) {
28-
if n > k.len {
29-
return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len))
30-
}
31-
32-
newKey := &Key{len: n}
33-
copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:mnd
34-
35-
// Shift right by the number of bits that are not needed
36-
shift := k.len - n
37-
for i := len(newKey.bitset) - 1; i >= 0; i-- {
38-
newKey.bitset[i] >>= shift
39-
if i > 0 {
40-
newKey.bitset[i] |= newKey.bitset[i-1] << (8 - shift)
41-
}
42-
}
43-
44-
return newKey, nil
45-
}
46-
4728
func (k *Key) bytesNeeded() uint {
4829
const byteBits = 8
4930
return (uint(k.len) + (byteBits - 1)) / byteBits
@@ -96,31 +77,48 @@ func (k *Key) Equal(other *Key) bool {
9677
return k.len == other.len && k.bitset == other.bitset
9778
}
9879

99-
func (k *Key) Test(bit uint8) bool {
80+
// IsBitSet returns whether the bit at the given position is 1.
81+
// Position 0 represents the least significant (rightmost) bit.
82+
func (k *Key) IsBitSet(position uint8) bool {
10083
const LSB = uint8(0x1)
101-
byteIdx := bit / 8
84+
byteIdx := position / 8
10285
byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1]
103-
bitIdx := bit % 8
86+
bitIdx := position % 8
10487
return ((byteAtIdx >> bitIdx) & LSB) != 0
10588
}
10689

107-
func (k *Key) String() string {
108-
return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:]))
109-
}
110-
111-
// DeleteLSB right shifts and shortens the key
112-
func (k *Key) DeleteLSB(n uint8) {
90+
// shiftRight removes n least significant bits from the key by performing a right shift
91+
// operation and reducing the key length. For example, if the key contains bits
92+
// "1111 0000" (length=8) and n=4, the result will be "1111" (length=4).
93+
//
94+
// The operation is destructive - it modifies the key in place.
95+
func (k *Key) shiftRight(n uint8) {
11396
if k.len < n {
11497
panic("deleting more bits than there are")
11598
}
11699

100+
if n == 0 {
101+
return
102+
}
103+
117104
var bigInt big.Int
118105
bigInt.SetBytes(k.bitset[:])
119106
bigInt.Rsh(&bigInt, uint(n))
120107
bigInt.FillBytes(k.bitset[:])
121108
k.len -= n
122109
}
123110

111+
// MostSignificantBits returns a new key with the most significant n bits of the current key.
112+
func (k *Key) MostSignificantBits(n uint8) (*Key, error) {
113+
if n > k.len {
114+
return nil, fmt.Errorf("cannot get more bits than the key length")
115+
}
116+
117+
keyCopy := k.Copy()
118+
keyCopy.shiftRight(k.len - n)
119+
return &keyCopy, nil
120+
}
121+
124122
// Truncate truncates key to `length` bits by clearing the remaining upper bits
125123
func (k *Key) Truncate(length uint8) {
126124
k.len = length
@@ -136,20 +134,53 @@ func (k *Key) Truncate(length uint8) {
136134
}
137135
}
138136

139-
func (k *Key) RemoveLastBit() {
140-
if k.len == 0 {
141-
return
142-
}
137+
func (k *Key) String() string {
138+
return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:]))
139+
}
143140

144-
k.len--
141+
// Copy returns a deep copy of the key
142+
func (k *Key) Copy() Key {
143+
newKey := Key{len: k.len}
144+
copy(newKey.bitset[:], k.bitset[:])
145+
return newKey
146+
}
145147

146-
unusedBytes := k.unusedBytes()
147-
clear(unusedBytes)
148+
func (k *Key) Bytes() [32]byte {
149+
var result [32]byte
150+
copy(result[:], k.bitset[:])
151+
return result
152+
}
148153

149-
// clear upper bits on the last used byte
150-
inUseBytes := k.inUseBytes()
151-
unusedBitsCount := 8 - (k.len % 8)
152-
if unusedBitsCount != 8 && len(inUseBytes) > 0 {
153-
inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount
154+
// findCommonKey finds the set of common MSB bits in two key bitsets.
155+
func findCommonKey(longerKey, shorterKey *Key) (Key, bool) {
156+
divergentBit := findDivergentBit(longerKey, shorterKey)
157+
158+
if divergentBit == 0 {
159+
return *NilKey, false
154160
}
161+
162+
commonKey := *shorterKey
163+
commonKey.shiftRight(shorterKey.Len() - divergentBit + 1)
164+
return commonKey, divergentBit == shorterKey.Len()+1
165+
}
166+
167+
// findDivergentBit finds the first bit that is different between two keys,
168+
// starting from the most significant bit of both keys.
169+
func findDivergentBit(longerKey, shorterKey *Key) uint8 {
170+
divergentBit := uint8(0)
171+
for divergentBit <= shorterKey.Len() &&
172+
longerKey.IsBitSet(longerKey.Len()-divergentBit) == shorterKey.IsBitSet(shorterKey.Len()-divergentBit) {
173+
divergentBit++
174+
}
175+
return divergentBit
176+
}
177+
178+
func isSubset(longerKey, shorterKey *Key) bool {
179+
divergentBit := findDivergentBit(longerKey, shorterKey)
180+
return divergentBit == shorterKey.Len()+1
181+
}
182+
183+
func FeltToKey(length uint8, key *felt.Felt) Key {
184+
keyBytes := key.Bytes()
185+
return NewKey(length, keyBytes[:])
155186
}

core/trie/key_test.go

+115-41
Original file line numberDiff line numberDiff line change
@@ -68,47 +68,6 @@ func BenchmarkKeyEncoding(b *testing.B) {
6868
}
6969
}
7070

71-
func TestKeyTest(t *testing.T) {
72-
key := trie.NewKey(44, []byte{0x10, 0x02})
73-
for i := 0; i < int(key.Len()); i++ {
74-
assert.Equal(t, i == 1 || i == 12, key.Test(uint8(i)), i)
75-
}
76-
}
77-
78-
func TestDeleteLSB(t *testing.T) {
79-
key := trie.NewKey(16, []byte{0xF3, 0x04})
80-
81-
tests := map[string]struct {
82-
shiftAmount uint8
83-
expectedKey trie.Key
84-
}{
85-
"delete 0 bits": {
86-
shiftAmount: 0,
87-
expectedKey: key,
88-
},
89-
"delete 4 bits": {
90-
shiftAmount: 4,
91-
expectedKey: trie.NewKey(12, []byte{0x0F, 0x30}),
92-
},
93-
"delete 8 bits": {
94-
shiftAmount: 8,
95-
expectedKey: trie.NewKey(8, []byte{0xF3}),
96-
},
97-
"delete 9 bits": {
98-
shiftAmount: 9,
99-
expectedKey: trie.NewKey(7, []byte{0x79}),
100-
},
101-
}
102-
103-
for desc, test := range tests {
104-
t.Run(desc, func(t *testing.T) {
105-
copyKey := key
106-
copyKey.DeleteLSB(test.shiftAmount)
107-
assert.Equal(t, test.expectedKey, copyKey)
108-
})
109-
}
110-
}
111-
11271
func TestTruncate(t *testing.T) {
11372
tests := map[string]struct {
11473
key trie.Key
@@ -153,3 +112,118 @@ func TestTruncate(t *testing.T) {
153112
})
154113
}
155114
}
115+
116+
func TestKeyTest(t *testing.T) {
117+
key := trie.NewKey(44, []byte{0x10, 0x02})
118+
for i := 0; i < int(key.Len()); i++ {
119+
assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i)
120+
}
121+
}
122+
123+
func TestIsBitSet(t *testing.T) {
124+
tests := map[string]struct {
125+
key trie.Key
126+
position uint8
127+
expected bool
128+
}{
129+
"single byte, LSB set": {
130+
key: trie.NewKey(8, []byte{0x01}),
131+
position: 0,
132+
expected: true,
133+
},
134+
"single byte, MSB set": {
135+
key: trie.NewKey(8, []byte{0x80}),
136+
position: 7,
137+
expected: true,
138+
},
139+
"single byte, middle bit set": {
140+
key: trie.NewKey(8, []byte{0x10}),
141+
position: 4,
142+
expected: true,
143+
},
144+
"single byte, bit not set": {
145+
key: trie.NewKey(8, []byte{0xFE}),
146+
position: 0,
147+
expected: false,
148+
},
149+
"multiple bytes, LSB set": {
150+
key: trie.NewKey(16, []byte{0x00, 0x02}),
151+
position: 1,
152+
expected: true,
153+
},
154+
"multiple bytes, MSB set": {
155+
key: trie.NewKey(16, []byte{0x01, 0x00}),
156+
position: 8,
157+
expected: true,
158+
},
159+
"multiple bytes, no bits set": {
160+
key: trie.NewKey(16, []byte{0x00, 0x00}),
161+
position: 7,
162+
expected: false,
163+
},
164+
"check all bits in pattern": {
165+
key: trie.NewKey(8, []byte{0xA5}), // 10100101
166+
position: 0,
167+
expected: true,
168+
},
169+
}
170+
171+
// Additional test for 0xA5 pattern
172+
key := trie.NewKey(8, []byte{0xA5}) // 10100101
173+
expectedBits := []bool{true, false, true, false, false, true, false, true}
174+
for i, expected := range expectedBits {
175+
assert.Equal(t, expected, key.IsBitSet(uint8(i)), "bit %d in 0xA5", i)
176+
}
177+
178+
for name, tc := range tests {
179+
t.Run(name, func(t *testing.T) {
180+
result := tc.key.IsBitSet(tc.position)
181+
assert.Equal(t, tc.expected, result)
182+
})
183+
}
184+
}
185+
186+
func TestMostSignificantBits(t *testing.T) {
187+
tests := []struct {
188+
name string
189+
key trie.Key
190+
n uint8
191+
want trie.Key
192+
expectErr bool
193+
}{
194+
{
195+
name: "Valid case",
196+
key: trie.NewKey(8, []byte{0b11110000}),
197+
n: 4,
198+
want: trie.NewKey(4, []byte{0b00001111}),
199+
expectErr: false,
200+
},
201+
{
202+
name: "Request more bits than available",
203+
key: trie.NewKey(8, []byte{0b11110000}),
204+
n: 10,
205+
want: trie.Key{},
206+
expectErr: true,
207+
},
208+
{
209+
name: "Zero bits requested",
210+
key: trie.NewKey(8, []byte{0b11110000}),
211+
n: 0,
212+
want: trie.NewKey(0, []byte{}),
213+
expectErr: false,
214+
},
215+
}
216+
217+
for _, tt := range tests {
218+
t.Run(tt.name, func(t *testing.T) {
219+
got, err := tt.key.MostSignificantBits(tt.n)
220+
if (err != nil) != tt.expectErr {
221+
t.Errorf("MostSignificantBits() error = %v, expectErr %v", err, tt.expectErr)
222+
return
223+
}
224+
if !tt.expectErr && !got.Equal(&tt.want) {
225+
t.Errorf("MostSignificantBits() = %v, want %v", got, tt.want)
226+
}
227+
})
228+
}
229+
}

core/trie/node.go

+53
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package trie
33
import (
44
"bytes"
55
"errors"
6+
"fmt"
67

78
"github.com/NethermindEth/juno/core/felt"
89
)
@@ -138,3 +139,55 @@ func (n *Node) UnmarshalBinary(data []byte) error {
138139
n.RightHash.SetBytes(data[:felt.Bytes])
139140
return nil
140141
}
142+
143+
func (n *Node) String() string {
144+
return fmt.Sprintf("Node{Value: %s, Left: %s, Right: %s, LeftHash: %s, RightHash: %s}", n.Value, n.Left, n.Right, n.LeftHash, n.RightHash)
145+
}
146+
147+
// Update the receiver with non-nil fields from the `other` Node.
148+
// If a field is non-nil in both Nodes, they must be equal, or an error is returned.
149+
//
150+
// This method modifies the receiver in-place and returns an error if any field conflicts are detected.
151+
//
152+
//nolint:gocyclo
153+
func (n *Node) Update(other *Node) error {
154+
// First validate all fields for conflicts
155+
if n.Value != nil && other.Value != nil && !n.Value.Equal(other.Value) {
156+
return fmt.Errorf("conflicting Values: %v != %v", n.Value, other.Value)
157+
}
158+
159+
if n.Left != nil && other.Left != nil && !n.Left.Equal(NilKey) && !other.Left.Equal(NilKey) && !n.Left.Equal(other.Left) {
160+
return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left)
161+
}
162+
163+
if n.Right != nil && other.Right != nil && !n.Right.Equal(NilKey) && !other.Right.Equal(NilKey) && !n.Right.Equal(other.Right) {
164+
return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right)
165+
}
166+
167+
if n.LeftHash != nil && other.LeftHash != nil && !n.LeftHash.Equal(other.LeftHash) {
168+
return fmt.Errorf("conflicting LeftHash: %v != %v", n.LeftHash, other.LeftHash)
169+
}
170+
171+
if n.RightHash != nil && other.RightHash != nil && !n.RightHash.Equal(other.RightHash) {
172+
return fmt.Errorf("conflicting RightHash: %v != %v", n.RightHash, other.RightHash)
173+
}
174+
175+
// After validation, perform all updates
176+
if other.Value != nil {
177+
n.Value = other.Value
178+
}
179+
if other.Left != nil && !other.Left.Equal(NilKey) {
180+
n.Left = other.Left
181+
}
182+
if other.Right != nil && !other.Right.Equal(NilKey) {
183+
n.Right = other.Right
184+
}
185+
if other.LeftHash != nil {
186+
n.LeftHash = other.LeftHash
187+
}
188+
if other.RightHash != nil {
189+
n.RightHash = other.RightHash
190+
}
191+
192+
return nil
193+
}

0 commit comments

Comments
 (0)