Skip to content

Commit

Permalink
Add Cmp()
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Jan 7, 2025
1 parent c6c8183 commit 38d930b
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 13 deletions.
54 changes: 46 additions & 8 deletions core/trie/bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool {
return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen))
}

// Sets the bit array to the most significant 'n' bits of x.
// Sets the bit array to the most significant 'n' bits of x, that is position 0 to n (exclusive).
// If n >= x.len, the bit array is an exact copy of x.
// For example:
//
Expand Down Expand Up @@ -480,14 +480,10 @@ func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray {
}

// Sets the bit array to a single bit.
func (b *BitArray) SetBit(bit bool) *BitArray {
func (b *BitArray) SetBit(bit uint8) *BitArray {
b.len = 1
if bit {
b.words[0] = 1
} else {
b.words[0] = 0
}
b.truncateToLength()
b.words[0] = uint64(bit & 1)
b.words[1], b.words[2], b.words[3] = 0, 0, 0
return b
}

Expand All @@ -503,7 +499,16 @@ func (b *BitArray) Copy() BitArray {
return res
}

// Returns the encoded string representation of the bit array.
func (b *BitArray) EncodedString() string {
var res []byte
res = append(res, b.len)
res = append(res, b.Bytes()...)
return string(res)

Check warning on line 507 in core/trie/bitarray.go

View check run for this annotation

Codecov / codecov/patch

core/trie/bitarray.go#L503-L507

Added lines #L503 - L507 were not covered by tests
}

// Returns a string representation of the bit array.
// This is typically used for logging or debugging.
func (b *BitArray) String() string {
return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes()))

Check warning on line 513 in core/trie/bitarray.go

View check run for this annotation

Codecov / codecov/patch

core/trie/bitarray.go#L512-L513

Added lines #L512 - L513 were not covered by tests
}
Expand Down Expand Up @@ -615,3 +620,36 @@ func findFirstSetBit(b *BitArray) uint8 {
// All bits are zero, no set bit found
return 0
}

// Cmp compares two bit arrays lexicographically.
// The comparison is first done by length, then by content if lengths are equal.
// Returns:
//
// -1 if b < x
// 0 if b == x
// 1 if b > x
func (b *BitArray) Cmp(x *BitArray) int {
// First compare lengths
if b.len < x.len {
return -1
}
if b.len > x.len {
return 1
}

// Lengths are equal, compare the actual bits
d0, carry := bits.Sub64(b.words[0], x.words[0], 0)
d1, carry := bits.Sub64(b.words[1], x.words[1], carry)
d2, carry := bits.Sub64(b.words[2], x.words[2], carry)
d3, carry := bits.Sub64(b.words[3], x.words[3], carry)

if carry == 1 {
return -1
}

if d0|d1|d2|d3 == 0 {
return 0
}

return 1
}
94 changes: 89 additions & 5 deletions core/trie/bitarray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1689,20 +1689,20 @@ func TestSetFeltValidation(t *testing.T) {
func TestSetBit(t *testing.T) {
tests := []struct {
name string
bit bool
bit uint8
want BitArray
}{
{
name: "set bit false",
bit: false,
name: "set bit 0",
bit: 0,
want: BitArray{
len: 1,
words: [4]uint64{0, 0, 0, 0},
},
},
{
name: "set bit true",
bit: true,
name: "set bit 1",
bit: 1,
want: BitArray{
len: 1,
words: [4]uint64{1, 0, 0, 0},
Expand All @@ -1719,3 +1719,87 @@ func TestSetBit(t *testing.T) {
})
}
}

func TestCmp(t *testing.T) {
tests := []struct {
name string
x BitArray
y BitArray
want int
}{
{
name: "equal empty arrays",
x: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}},
y: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}},
want: 0,
},
{
name: "equal non-empty arrays",
x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}},
y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}},
want: 0,
},
{
name: "different lengths - x shorter",
x: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}},
y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}},
want: -1,
},
{
name: "different lengths - x longer",
x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}},
y: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}},
want: 1,
},
{
name: "same length, x < y in first word",
x: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}},
y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}},
want: -1,
},
{
name: "same length, x > y in first word",
x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}},
y: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}},
want: 1,
},
{
name: "same length, difference in last word",
x: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFFF}},
y: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFF0}},
want: 1,
},
{
name: "same length, sparse bits",
x: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}},
y: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}},
want: 1,
},
{
name: "max length difference",
x: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}},
y: BitArray{len: 1, words: [4]uint64{0x1, 0, 0, 0}},
want: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.x.Cmp(&tt.y)
if got != tt.want {
t.Errorf("Cmp() = %v, want %v", got, tt.want)
}

// Test anti-symmetry: if x.Cmp(y) = z then y.Cmp(x) = -z
gotReverse := tt.y.Cmp(&tt.x)
if gotReverse != -tt.want {
t.Errorf("Reverse Cmp() = %v, want %v", gotReverse, -tt.want)
}

// Test transitivity with self: x.Cmp(x) should always be 0
if tt.x.Cmp(&tt.x) != 0 {
t.Error("Self Cmp() != 0")
}
})
}
}

0 comments on commit 38d930b

Please sign in to comment.