@@ -84,8 +84,8 @@ impl Poly2 {
84
84
}
85
85
86
86
/// Raises `x` to the power `n` using binary exponentiation,
87
- /// with (1 to 2)*lg(n) scalar multiplications.
88
- /// TODO: a consttime version of this would be awfully similar to a Montgomery ladder.
87
+ /// with ` (1 to 2)*lg(n)` scalar multiplications.
88
+ /// TODO: a consttime version of this would be similar to a Montgomery ladder.
89
89
pub fn scalar_exp_vartime ( x : & Scalar , mut n : u64 ) -> Scalar {
90
90
let mut result = Scalar :: one ( ) ;
91
91
let mut aux = * x; // x, x^2, x^4, x^8, ...
@@ -95,38 +95,85 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
95
95
result = result * aux;
96
96
}
97
97
n = n >> 1 ;
98
- aux = aux * aux; // FIXME: one unnecessary mult at the last step here!
98
+ if n > 0 {
99
+ aux = aux * aux;
100
+ }
99
101
}
100
102
result
101
103
}
102
104
103
- /// Takes the sum of all the powers of `x`, up to `n`
104
- /// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplcations and additions.
105
- /// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions.
106
- /// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2.
107
- pub fn sum_of_powers ( x : & Scalar , n : usize ) -> Scalar {
108
- if !n. is_power_of_two ( ) {
109
- return sum_of_powers_slow ( x, n) ;
110
- }
111
- if n == 0 || n == 1 {
112
- return Scalar :: from_u64 ( n as u64 ) ;
113
- }
114
- let mut m = n;
115
- let mut result = Scalar :: one ( ) + x;
116
- let mut factor = * x;
117
- while m > 2 {
118
- factor = factor * factor;
119
- result = result + factor * result;
120
- m = m / 2 ;
105
+ /// Computes the sum of all the powers of \\(x\\) \\(S(n) = (x^0 + \dots + x^{n-1})\\)
106
+ /// using \\(O(\lg n)\\) multiplications and additions. Length \\(n\\) is not considered secret
107
+ /// and algorithm is fastest when \\(n\\) is the power of two.
108
+ ///
109
+ /// ### Algorithm overview
110
+ ///
111
+ /// First, let \\(n\\) be a power of two.
112
+ /// Then, we can divide the polynomial in two halves like so:
113
+ /// \\[
114
+ /// \begin{aligned}
115
+ /// S(n) &= (1+\dots+x^{n-1}) \\\\
116
+ /// &= (1+\dots+x^{n/2-1}) + x^{n/2} (1+\dots+x^{n/2-1}) \\\\
117
+ /// &= s + x^{n/2} s.
118
+ /// \end{aligned}
119
+ /// \\]
120
+ /// We can divide each \\(s\\) in half until we arrive to a degree-1 polynomial \\((1+x\cdot 1)\\).
121
+ /// Recursively, the total sum can be defined as:
122
+ /// \\[
123
+ /// \begin{aligned}
124
+ /// S(0) &= 0 \\\\
125
+ /// S(n) &= s_{\lg n} \\\\
126
+ /// s_0 &= 1 \\\\
127
+ /// s_i &= s_{i-1} + x^{2^{i-1}} s_{i-1}
128
+ /// \end{aligned}
129
+ /// \\]
130
+ /// This representation allows us to square \\(x\\) only \\(\lg n\\) times.
131
+ ///
132
+ /// Lets apply this to \\(n\\) which is not a power of two (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using
133
+ /// bits \\(b_i\\) in \\(\\{0,1\\}\\):
134
+ /// \\[
135
+ /// n = b_0 2^0 + \dots + b_{k-1} 2^{k-1}
136
+ /// \\]
137
+ /// If we scan the bits of \\(n\\) from low to high (\\(i \in [0,k)\\)),
138
+ /// we can conditionally (if \\(b_i = 1\\)) add to a resulting scalar
139
+ /// an intermediate polynomial with \\(2^i\\) terms using the above algorithm,
140
+ /// provided we offset the polynomial by \\(x^{n_i}\\), the next power of \\(x\\)
141
+ /// for the existing sum, where \\(n_i = \sum_{j=0}^{i-1} b_j 2^j\\).
142
+ ///
143
+ /// The full algorithm becomes:
144
+ /// \\[
145
+ /// \begin{aligned}
146
+ /// S(0) &= 0 \\\\
147
+ /// S(1) &= 1 \\\\
148
+ /// S(i) &= S(i-1) + x^{n_i} s_i b_i\\\\
149
+ /// &= S(i-1) + x^{n_{i-1}} (x^{2^{i-1}})^{b_{i-1}} s_i b_i
150
+ /// \end{aligned}
151
+ /// \\]
152
+ pub fn sum_of_powers ( x : & Scalar , mut n : usize ) -> Scalar {
153
+ let mut result = Scalar :: zero ( ) ;
154
+ let mut f = Scalar :: one ( ) ; // power of x to offset subsequent polynomials based on lower bits of n.
155
+ let mut s = Scalar :: one ( ) ; // power-of-two polynomial: 1, 1+x, 1+x+x^2+x^3, ...
156
+ let mut p = * x; // x, x^2, x^4, ..., x^{2^i}
157
+ while n > 0 {
158
+ // take a bit from n
159
+ let bit = n & 1 ;
160
+ n = n >> 1 ;
161
+
162
+ if bit == 1 {
163
+ // bits of `n` are not secret, so it's okay to be vartime because of `n` value.
164
+ result += f * s;
165
+ if n > 0 { // avoid multiplication if no bits left
166
+ f = f * p;
167
+ }
168
+ }
169
+ if n > 0 { // avoid multiplication if no bits left
170
+ s = s + p * s;
171
+ p = p * p;
172
+ }
121
173
}
122
174
result
123
175
}
124
176
125
- // takes the sum of all of the powers of x, up to n
126
- fn sum_of_powers_slow ( x : & Scalar , n : usize ) -> Scalar {
127
- exp_iter ( * x) . take ( n) . fold ( Scalar :: zero ( ) , |acc, x| acc + x)
128
- }
129
-
130
177
#[ cfg( test) ]
131
178
mod tests {
132
179
use super :: * ;
@@ -185,9 +232,14 @@ mod tests {
185
232
) ;
186
233
}
187
234
235
+ // takes the sum of all of the powers of x, up to n
236
+ fn sum_of_powers_slow ( x : & Scalar , n : usize ) -> Scalar {
237
+ exp_iter ( * x) . take ( n) . fold ( Scalar :: zero ( ) , |acc, x| acc + x)
238
+ }
239
+
188
240
#[ test]
189
- fn test_sum_of_powers ( ) {
190
- let x = Scalar :: from_u64 ( 10 ) ;
241
+ fn test_sum_of_powers_pow2 ( ) {
242
+ let x = Scalar :: from_u64 ( 1337133713371337 ) ;
191
243
assert_eq ! ( sum_of_powers_slow( & x, 0 ) , sum_of_powers( & x, 0 ) ) ;
192
244
assert_eq ! ( sum_of_powers_slow( & x, 1 ) , sum_of_powers( & x, 1 ) ) ;
193
245
assert_eq ! ( sum_of_powers_slow( & x, 2 ) , sum_of_powers( & x, 2 ) ) ;
@@ -199,14 +251,16 @@ mod tests {
199
251
}
200
252
201
253
#[ test]
202
- fn test_sum_of_powers_slow ( ) {
254
+ fn test_sum_of_powers_non_pow2 ( ) {
203
255
let x = Scalar :: from_u64 ( 10 ) ;
204
- assert_eq ! ( sum_of_powers_slow( & x, 0 ) , Scalar :: zero( ) ) ;
205
- assert_eq ! ( sum_of_powers_slow( & x, 1 ) , Scalar :: one( ) ) ;
206
- assert_eq ! ( sum_of_powers_slow( & x, 2 ) , Scalar :: from_u64( 11 ) ) ;
207
- assert_eq ! ( sum_of_powers_slow( & x, 3 ) , Scalar :: from_u64( 111 ) ) ;
208
- assert_eq ! ( sum_of_powers_slow( & x, 4 ) , Scalar :: from_u64( 1111 ) ) ;
209
- assert_eq ! ( sum_of_powers_slow( & x, 5 ) , Scalar :: from_u64( 11111 ) ) ;
210
- assert_eq ! ( sum_of_powers_slow( & x, 6 ) , Scalar :: from_u64( 111111 ) ) ;
256
+ assert_eq ! ( sum_of_powers( & x, 0 ) , Scalar :: zero( ) ) ;
257
+ assert_eq ! ( sum_of_powers( & x, 1 ) , Scalar :: one( ) ) ;
258
+ assert_eq ! ( sum_of_powers( & x, 2 ) , Scalar :: from_u64( 11 ) ) ;
259
+ assert_eq ! ( sum_of_powers( & x, 3 ) , Scalar :: from_u64( 111 ) ) ;
260
+ assert_eq ! ( sum_of_powers( & x, 4 ) , Scalar :: from_u64( 1111 ) ) ;
261
+ assert_eq ! ( sum_of_powers( & x, 5 ) , Scalar :: from_u64( 11111 ) ) ;
262
+ assert_eq ! ( sum_of_powers( & x, 6 ) , Scalar :: from_u64( 111111 ) ) ;
263
+ assert_eq ! ( sum_of_powers( & x, 7 ) , Scalar :: from_u64( 1111111 ) ) ;
264
+ assert_eq ! ( sum_of_powers( & x, 8 ) , Scalar :: from_u64( 11111111 ) ) ;
211
265
}
212
266
}
0 commit comments