1
1
package bandersnatch
2
2
3
3
import (
4
- "math"
5
4
"math/big"
6
5
7
6
"github.com/consensys/gnark-crypto/ecc"
@@ -30,14 +29,13 @@ func (p *PointProj) phi(p1 *PointProj) *PointProj {
30
29
return p
31
30
}
32
31
33
- // ScalarMultiplication scalar multiplication ( GLV) of a point
32
+ // scalarMulGLV is the GLV scalar multiplication of a point
34
33
// p1 in projective coordinates with a scalar in big.Int
35
34
func (p * PointProj ) scalarMulGLV (p1 * PointProj , scalar * big.Int ) * PointProj {
36
35
37
36
initOnce .Do (initCurveParams )
38
37
39
38
var table [15 ]PointProj
40
- var zero big.Int
41
39
var res PointProj
42
40
var k1 , k2 fr.Element
43
41
@@ -50,38 +48,45 @@ func (p *PointProj) scalarMulGLV(p1 *PointProj, scalar *big.Int) *PointProj {
50
48
// split the scalar, modifies +-p1, phi(p1) accordingly
51
49
k := ecc .SplitScalar (scalar , & curveParams .glvBasis )
52
50
53
- if k [0 ].Cmp ( & zero ) == - 1 {
51
+ if k [0 ].Sign ( ) == - 1 {
54
52
k [0 ].Neg (& k [0 ])
55
53
table [0 ].Neg (& table [0 ])
56
54
}
57
- if k [1 ].Cmp ( & zero ) == - 1 {
55
+ if k [1 ].Sign ( ) == - 1 {
58
56
k [1 ].Neg (& k [1 ])
59
57
table [3 ].Neg (& table [3 ])
60
58
}
61
59
62
60
// precompute table (2 bits sliding window)
63
61
// table[b3b2b1b0-1] = b3b2*phi(p1) + b1b0*p1 if b3b2b1b0 != 0
64
62
table [1 ].Double (& table [0 ])
65
- table [2 ].Set ( & table [ 1 ]). Add (& table [2 ], & table [0 ])
66
- table [4 ].Set ( & table [ 3 ]). Add (& table [4 ], & table [0 ])
67
- table [5 ].Set ( & table [ 3 ]). Add (& table [5 ], & table [1 ])
68
- table [6 ].Set ( & table [ 3 ]). Add (& table [6 ], & table [2 ])
63
+ table [2 ].Add (& table [1 ], & table [0 ])
64
+ table [4 ].Add (& table [3 ], & table [0 ])
65
+ table [5 ].Add (& table [3 ], & table [1 ])
66
+ table [6 ].Add (& table [3 ], & table [2 ])
69
67
table [7 ].Double (& table [3 ])
70
- table [8 ].Set (& table [7 ]).Add (& table [8 ], & table [0 ])
71
- table [9 ].Set (& table [7 ]).Add (& table [9 ], & table [1 ])
72
- table [10 ].Set (& table [7 ]).Add (& table [10 ], & table [2 ])
73
- table [11 ].Set (& table [7 ]).Add (& table [11 ], & table [3 ])
74
- table [12 ].Set (& table [11 ]).Add (& table [12 ], & table [0 ])
75
- table [13 ].Set (& table [11 ]).Add (& table [13 ], & table [1 ])
76
- table [14 ].Set (& table [11 ]).Add (& table [14 ], & table [2 ])
77
-
78
- // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 bits long max
68
+ table [8 ].Add (& table [7 ], & table [0 ])
69
+ table [9 ].Add (& table [7 ], & table [1 ])
70
+ table [10 ].Add (& table [7 ], & table [2 ])
71
+ table [11 ].Add (& table [7 ], & table [3 ])
72
+ table [12 ].Add (& table [11 ], & table [0 ])
73
+ table [13 ].Add (& table [11 ], & table [1 ])
74
+ table [14 ].Add (& table [11 ], & table [2 ])
75
+
76
+ // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max
77
+ // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift
79
78
k1 = k1 .SetBigInt (& k [0 ]).Bits ()
80
79
k2 = k2 .SetBigInt (& k [1 ]).Bits ()
81
80
82
- // loop starts from len(k1)/2 due to the bounds
83
- // fr.Limbs == Order.limbs
84
- for i := int (math .Ceil (fr .Limbs / 2. - 1 )); i >= 0 ; i -- {
81
+ // we don't target constant-timeness so we check first if we increase the bounds or not
82
+ maxBit := k1 .BitLen ()
83
+ if k2 .BitLen () > maxBit {
84
+ maxBit = k2 .BitLen ()
85
+ }
86
+ hiWordIndex := (maxBit - 1 ) / 64
87
+
88
+ // loop starts from len(k1)/2 or len(k1)/2+1 due to the bounds
89
+ for i := hiWordIndex ; i >= 0 ; i -- {
85
90
mask := uint64 (3 ) << 62
86
91
for j := 0 ; j < 32 ; j ++ {
87
92
res .Double (& res ).Double (& res )
@@ -121,13 +126,13 @@ func (p *PointExtended) phi(p1 *PointExtended) *PointExtended {
121
126
return p
122
127
}
123
128
124
- // ScalarMultiplication scalar multiplication ( GLV) of a point
129
+ // scalarMulGLV is the GLV scalar multiplication of a point
125
130
// p1 in projective coordinates with a scalar in big.Int
126
131
func (p * PointExtended ) scalarMulGLV (p1 * PointExtended , scalar * big.Int ) * PointExtended {
132
+
127
133
initOnce .Do (initCurveParams )
128
134
129
135
var table [15 ]PointExtended
130
- var zero big.Int
131
136
var res PointExtended
132
137
var k1 , k2 fr.Element
133
138
@@ -140,38 +145,45 @@ func (p *PointExtended) scalarMulGLV(p1 *PointExtended, scalar *big.Int) *PointE
140
145
// split the scalar, modifies +-p1, phi(p1) accordingly
141
146
k := ecc .SplitScalar (scalar , & curveParams .glvBasis )
142
147
143
- if k [0 ].Cmp ( & zero ) == - 1 {
148
+ if k [0 ].Sign ( ) == - 1 {
144
149
k [0 ].Neg (& k [0 ])
145
150
table [0 ].Neg (& table [0 ])
146
151
}
147
- if k [1 ].Cmp ( & zero ) == - 1 {
152
+ if k [1 ].Sign ( ) == - 1 {
148
153
k [1 ].Neg (& k [1 ])
149
154
table [3 ].Neg (& table [3 ])
150
155
}
151
156
152
157
// precompute table (2 bits sliding window)
153
158
// table[b3b2b1b0-1] = b3b2*phi(p1) + b1b0*p1 if b3b2b1b0 != 0
154
159
table [1 ].Double (& table [0 ])
155
- table [2 ].Set ( & table [ 1 ]). Add (& table [2 ], & table [0 ])
156
- table [4 ].Set ( & table [ 3 ]). Add (& table [4 ], & table [0 ])
157
- table [5 ].Set ( & table [ 3 ]). Add (& table [5 ], & table [1 ])
158
- table [6 ].Set ( & table [ 3 ]). Add (& table [6 ], & table [2 ])
160
+ table [2 ].Add (& table [1 ], & table [0 ])
161
+ table [4 ].Add (& table [3 ], & table [0 ])
162
+ table [5 ].Add (& table [3 ], & table [1 ])
163
+ table [6 ].Add (& table [3 ], & table [2 ])
159
164
table [7 ].Double (& table [3 ])
160
- table [8 ].Set (& table [7 ]).Add (& table [8 ], & table [0 ])
161
- table [9 ].Set (& table [7 ]).Add (& table [9 ], & table [1 ])
162
- table [10 ].Set (& table [7 ]).Add (& table [10 ], & table [2 ])
163
- table [11 ].Set (& table [7 ]).Add (& table [11 ], & table [3 ])
164
- table [12 ].Set (& table [11 ]).Add (& table [12 ], & table [0 ])
165
- table [13 ].Set (& table [11 ]).Add (& table [13 ], & table [1 ])
166
- table [14 ].Set (& table [11 ]).Add (& table [14 ], & table [2 ])
167
-
168
- // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 bits long max
165
+ table [8 ].Add (& table [7 ], & table [0 ])
166
+ table [9 ].Add (& table [7 ], & table [1 ])
167
+ table [10 ].Add (& table [7 ], & table [2 ])
168
+ table [11 ].Add (& table [7 ], & table [3 ])
169
+ table [12 ].Add (& table [11 ], & table [0 ])
170
+ table [13 ].Add (& table [11 ], & table [1 ])
171
+ table [14 ].Add (& table [11 ], & table [2 ])
172
+
173
+ // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max
174
+ // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift
169
175
k1 = k1 .SetBigInt (& k [0 ]).Bits ()
170
176
k2 = k2 .SetBigInt (& k [1 ]).Bits ()
171
177
172
- // loop starts from len(k1)/2 due to the bounds
173
- // fr.Limbs == Order.limbs
174
- for i := int (math .Ceil (fr .Limbs / 2. - 1 )); i >= 0 ; i -- {
178
+ // we don't target constant-timeness so we check first if we increase the bounds or not
179
+ maxBit := k1 .BitLen ()
180
+ if k2 .BitLen () > maxBit {
181
+ maxBit = k2 .BitLen ()
182
+ }
183
+ hiWordIndex := (maxBit - 1 ) / 64
184
+
185
+ // loop starts from len(k1)/2 or len(k1)/2+1 due to the bounds
186
+ for i := hiWordIndex ; i >= 0 ; i -- {
175
187
mask := uint64 (3 ) << 62
176
188
for j := 0 ; j < 32 ; j ++ {
177
189
res .Double (& res ).Double (& res )
0 commit comments