Skip to content

Commit 6cf8884

Browse files
authored
fix(bandersnatch): GLV bounds + test (#516)
1 parent 7e5f929 commit 6cf8884

File tree

14 files changed

+377
-92
lines changed

14 files changed

+377
-92
lines changed

ecc/bls12-377/twistededwards/point.go

+18-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ecc/bls12-378/twistededwards/point.go

+18-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ecc/bls12-381/bandersnatch/endomorpism.go

+53-41
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package bandersnatch
22

33
import (
4-
"math"
54
"math/big"
65

76
"github.com/consensys/gnark-crypto/ecc"
@@ -30,14 +29,13 @@ func (p *PointProj) phi(p1 *PointProj) *PointProj {
3029
return p
3130
}
3231

33-
// ScalarMultiplication scalar multiplication (GLV) of a point
32+
// scalarMulGLV is the GLV scalar multiplication of a point
3433
// p1 in projective coordinates with a scalar in big.Int
3534
func (p *PointProj) scalarMulGLV(p1 *PointProj, scalar *big.Int) *PointProj {
3635

3736
initOnce.Do(initCurveParams)
3837

3938
var table [15]PointProj
40-
var zero big.Int
4139
var res PointProj
4240
var k1, k2 fr.Element
4341

@@ -50,38 +48,45 @@ func (p *PointProj) scalarMulGLV(p1 *PointProj, scalar *big.Int) *PointProj {
5048
// split the scalar, modifies +-p1, phi(p1) accordingly
5149
k := ecc.SplitScalar(scalar, &curveParams.glvBasis)
5250

53-
if k[0].Cmp(&zero) == -1 {
51+
if k[0].Sign() == -1 {
5452
k[0].Neg(&k[0])
5553
table[0].Neg(&table[0])
5654
}
57-
if k[1].Cmp(&zero) == -1 {
55+
if k[1].Sign() == -1 {
5856
k[1].Neg(&k[1])
5957
table[3].Neg(&table[3])
6058
}
6159

6260
// precompute table (2 bits sliding window)
6361
// table[b3b2b1b0-1] = b3b2*phi(p1) + b1b0*p1 if b3b2b1b0 != 0
6462
table[1].Double(&table[0])
65-
table[2].Set(&table[1]).Add(&table[2], &table[0])
66-
table[4].Set(&table[3]).Add(&table[4], &table[0])
67-
table[5].Set(&table[3]).Add(&table[5], &table[1])
68-
table[6].Set(&table[3]).Add(&table[6], &table[2])
63+
table[2].Add(&table[1], &table[0])
64+
table[4].Add(&table[3], &table[0])
65+
table[5].Add(&table[3], &table[1])
66+
table[6].Add(&table[3], &table[2])
6967
table[7].Double(&table[3])
70-
table[8].Set(&table[7]).Add(&table[8], &table[0])
71-
table[9].Set(&table[7]).Add(&table[9], &table[1])
72-
table[10].Set(&table[7]).Add(&table[10], &table[2])
73-
table[11].Set(&table[7]).Add(&table[11], &table[3])
74-
table[12].Set(&table[11]).Add(&table[12], &table[0])
75-
table[13].Set(&table[11]).Add(&table[13], &table[1])
76-
table[14].Set(&table[11]).Add(&table[14], &table[2])
77-
78-
// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 bits long max
68+
table[8].Add(&table[7], &table[0])
69+
table[9].Add(&table[7], &table[1])
70+
table[10].Add(&table[7], &table[2])
71+
table[11].Add(&table[7], &table[3])
72+
table[12].Add(&table[11], &table[0])
73+
table[13].Add(&table[11], &table[1])
74+
table[14].Add(&table[11], &table[2])
75+
76+
// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max
77+
// this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift
7978
k1 = k1.SetBigInt(&k[0]).Bits()
8079
k2 = k2.SetBigInt(&k[1]).Bits()
8180

82-
// loop starts from len(k1)/2 due to the bounds
83-
// fr.Limbs == Order.limbs
84-
for i := int(math.Ceil(fr.Limbs/2. - 1)); i >= 0; i-- {
81+
// we don't target constant-timeness so we check first if we increase the bounds or not
82+
maxBit := k1.BitLen()
83+
if k2.BitLen() > maxBit {
84+
maxBit = k2.BitLen()
85+
}
86+
hiWordIndex := (maxBit - 1) / 64
87+
88+
// loop starts from len(k1)/2 or len(k1)/2+1 due to the bounds
89+
for i := hiWordIndex; i >= 0; i-- {
8590
mask := uint64(3) << 62
8691
for j := 0; j < 32; j++ {
8792
res.Double(&res).Double(&res)
@@ -121,13 +126,13 @@ func (p *PointExtended) phi(p1 *PointExtended) *PointExtended {
121126
return p
122127
}
123128

124-
// ScalarMultiplication scalar multiplication (GLV) of a point
129+
// scalarMulGLV is the GLV scalar multiplication of a point
125130
// p1 in projective coordinates with a scalar in big.Int
126131
func (p *PointExtended) scalarMulGLV(p1 *PointExtended, scalar *big.Int) *PointExtended {
132+
127133
initOnce.Do(initCurveParams)
128134

129135
var table [15]PointExtended
130-
var zero big.Int
131136
var res PointExtended
132137
var k1, k2 fr.Element
133138

@@ -140,38 +145,45 @@ func (p *PointExtended) scalarMulGLV(p1 *PointExtended, scalar *big.Int) *PointE
140145
// split the scalar, modifies +-p1, phi(p1) accordingly
141146
k := ecc.SplitScalar(scalar, &curveParams.glvBasis)
142147

143-
if k[0].Cmp(&zero) == -1 {
148+
if k[0].Sign() == -1 {
144149
k[0].Neg(&k[0])
145150
table[0].Neg(&table[0])
146151
}
147-
if k[1].Cmp(&zero) == -1 {
152+
if k[1].Sign() == -1 {
148153
k[1].Neg(&k[1])
149154
table[3].Neg(&table[3])
150155
}
151156

152157
// precompute table (2 bits sliding window)
153158
// table[b3b2b1b0-1] = b3b2*phi(p1) + b1b0*p1 if b3b2b1b0 != 0
154159
table[1].Double(&table[0])
155-
table[2].Set(&table[1]).Add(&table[2], &table[0])
156-
table[4].Set(&table[3]).Add(&table[4], &table[0])
157-
table[5].Set(&table[3]).Add(&table[5], &table[1])
158-
table[6].Set(&table[3]).Add(&table[6], &table[2])
160+
table[2].Add(&table[1], &table[0])
161+
table[4].Add(&table[3], &table[0])
162+
table[5].Add(&table[3], &table[1])
163+
table[6].Add(&table[3], &table[2])
159164
table[7].Double(&table[3])
160-
table[8].Set(&table[7]).Add(&table[8], &table[0])
161-
table[9].Set(&table[7]).Add(&table[9], &table[1])
162-
table[10].Set(&table[7]).Add(&table[10], &table[2])
163-
table[11].Set(&table[7]).Add(&table[11], &table[3])
164-
table[12].Set(&table[11]).Add(&table[12], &table[0])
165-
table[13].Set(&table[11]).Add(&table[13], &table[1])
166-
table[14].Set(&table[11]).Add(&table[14], &table[2])
167-
168-
// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 bits long max
165+
table[8].Add(&table[7], &table[0])
166+
table[9].Add(&table[7], &table[1])
167+
table[10].Add(&table[7], &table[2])
168+
table[11].Add(&table[7], &table[3])
169+
table[12].Add(&table[11], &table[0])
170+
table[13].Add(&table[11], &table[1])
171+
table[14].Add(&table[11], &table[2])
172+
173+
// bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max
174+
// this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift
169175
k1 = k1.SetBigInt(&k[0]).Bits()
170176
k2 = k2.SetBigInt(&k[1]).Bits()
171177

172-
// loop starts from len(k1)/2 due to the bounds
173-
// fr.Limbs == Order.limbs
174-
for i := int(math.Ceil(fr.Limbs/2. - 1)); i >= 0; i-- {
178+
// we don't target constant-timeness so we check first if we increase the bounds or not
179+
maxBit := k1.BitLen()
180+
if k2.BitLen() > maxBit {
181+
maxBit = k2.BitLen()
182+
}
183+
hiWordIndex := (maxBit - 1) / 64
184+
185+
// loop starts from len(k1)/2 or len(k1)/2+1 due to the bounds
186+
for i := hiWordIndex; i >= 0; i-- {
175187
mask := uint64(3) << 62
176188
for j := 0; j < 32; j++ {
177189
res.Double(&res).Double(&res)

ecc/bls12-381/bandersnatch/point.go

+63
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ecc/bls12-381/bandersnatch/point_test.go

+32
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)