Skip to content

Commit 612a107

Browse files
committed
Implement fast sum of powers for any n
1 parent 2aeebef commit 612a107

File tree

1 file changed

+90
-36
lines changed

1 file changed

+90
-36
lines changed

src/util.rs

+90-36
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ impl Poly2 {
8484
}
8585

8686
/// Raises `x` to the power `n` using binary exponentiation,
87-
/// with (1 to 2)*lg(n) scalar multiplications.
88-
/// TODO: a consttime version of this would be awfully similar to a Montgomery ladder.
87+
/// with `(1 to 2)*lg(n)` scalar multiplications.
88+
/// TODO: a consttime version of this would be similar to a Montgomery ladder.
8989
pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
9090
let mut result = Scalar::one();
9191
let mut aux = *x; // x, x^2, x^4, x^8, ...
@@ -95,38 +95,85 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
9595
result = result * aux;
9696
}
9797
n = n >> 1;
98-
aux = aux * aux; // FIXME: one unnecessary mult at the last step here!
98+
if n > 0 {
99+
aux = aux * aux;
100+
}
99101
}
100102
result
101103
}
102104

103-
/// Takes the sum of all the powers of `x`, up to `n`
104-
/// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplcations and additions.
105-
/// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions.
106-
/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2.
107-
pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar {
108-
if !n.is_power_of_two() {
109-
return sum_of_powers_slow(x, n);
110-
}
111-
if n == 0 || n == 1 {
112-
return Scalar::from_u64(n as u64);
113-
}
114-
let mut m = n;
115-
let mut result = Scalar::one() + x;
116-
let mut factor = *x;
117-
while m > 2 {
118-
factor = factor * factor;
119-
result = result + factor * result;
120-
m = m / 2;
105+
/// Computes the sum of all the powers of \\(x\\) \\(S(n) = (x^0 + \dots + x^{n-1})\\)
106+
/// using \\(O(\lg n)\\) multiplications and additions. Length \\(n\\) is not considered secret
107+
/// and algorithm is fastest when \\(n\\) is the power of two.
108+
///
109+
/// ### Algorithm overview
110+
///
111+
/// First, let \\(n\\) be a power of two.
112+
/// Then, we can divide the polynomial in two halves like so:
113+
/// \\[
114+
/// \begin{aligned}
115+
/// S(n) &= (1+\dots+x^{n-1}) \\\\
116+
/// &= (1+\dots+x^{n/2-1}) + x^{n/2} (1+\dots+x^{n/2-1}) \\\\
117+
/// &= s + x^{n/2} s.
118+
/// \end{aligned}
119+
/// \\]
120+
/// We can divide each \\(s\\) in half until we arrive to a degree-1 polynomial \\((1+x\cdot 1)\\).
121+
/// Recursively, the total sum can be defined as:
122+
/// \\[
123+
/// \begin{aligned}
124+
/// S(0) &= 0 \\\\
125+
/// S(n) &= s_{\lg n} \\\\
126+
/// s_0 &= 1 \\\\
127+
/// s_i &= s_{i-1} + x^{2^{i-1}} s_{i-1}
128+
/// \end{aligned}
129+
/// \\]
130+
/// This representation allows us to square \\(x\\) only \\(\lg n\\) times.
131+
///
132+
/// Lets apply this to \\(n\\) which is not a power of two (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using
133+
/// bits \\(b_i\\) in \\(\\{0,1\\}\\):
134+
/// \\[
135+
/// n = b_0 2^0 + \dots + b_{k-1} 2^{k-1}
136+
/// \\]
137+
/// If we scan the bits of \\(n\\) from low to high (\\(i \in [0,k)\\)),
138+
/// we can conditionally (if \\(b_i = 1\\)) add to a resulting scalar
139+
/// an intermediate polynomial with \\(2^i\\) terms using the above algorithm,
140+
/// provided we offset the polynomial by \\(x^{n_i}\\), the next power of \\(x\\)
141+
/// for the existing sum, where \\(n_i = \sum_{j=0}^{i-1} b_j 2^j\\).
142+
///
143+
/// The full algorithm becomes:
144+
/// \\[
145+
/// \begin{aligned}
146+
/// S(0) &= 0 \\\\
147+
/// S(1) &= 1 \\\\
148+
/// S(i) &= S(i-1) + x^{n_i} s_i b_i\\\\
149+
/// &= S(i-1) + x^{n_{i-1}} (x^{2^{i-1}})^{b_{i-1}} s_i b_i
150+
/// \end{aligned}
151+
/// \\]
152+
pub fn sum_of_powers(x: &Scalar, mut n: usize) -> Scalar {
153+
let mut result = Scalar::zero();
154+
let mut f = Scalar::one(); // power of x to offset subsequent polynomials based on lower bits of n.
155+
let mut s = Scalar::one(); // power-of-two polynomial: 1, 1+x, 1+x+x^2+x^3, ...
156+
let mut p = *x; // x, x^2, x^4, ..., x^{2^i}
157+
while n > 0 {
158+
// take a bit from n
159+
let bit = n & 1;
160+
n = n >> 1;
161+
162+
if bit == 1 {
163+
// bits of `n` are not secret, so it's okay to be vartime because of `n` value.
164+
result += f * s;
165+
if n > 0 { // avoid multiplication if no bits left
166+
f = f * p;
167+
}
168+
}
169+
if n > 0 { // avoid multiplication if no bits left
170+
s = s + p * s;
171+
p = p * p;
172+
}
121173
}
122174
result
123175
}
124176

125-
// takes the sum of all of the powers of x, up to n
126-
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
127-
exp_iter(*x).take(n).fold(Scalar::zero(), |acc, x| acc + x)
128-
}
129-
130177
#[cfg(test)]
131178
mod tests {
132179
use super::*;
@@ -185,9 +232,14 @@ mod tests {
185232
);
186233
}
187234

235+
// takes the sum of all of the powers of x, up to n
236+
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
237+
exp_iter(*x).take(n).fold(Scalar::zero(), |acc, x| acc + x)
238+
}
239+
188240
#[test]
189-
fn test_sum_of_powers() {
190-
let x = Scalar::from_u64(10);
241+
fn test_sum_of_powers_pow2() {
242+
let x = Scalar::from_u64(1337133713371337);
191243
assert_eq!(sum_of_powers_slow(&x, 0), sum_of_powers(&x, 0));
192244
assert_eq!(sum_of_powers_slow(&x, 1), sum_of_powers(&x, 1));
193245
assert_eq!(sum_of_powers_slow(&x, 2), sum_of_powers(&x, 2));
@@ -199,14 +251,16 @@ mod tests {
199251
}
200252

201253
#[test]
202-
fn test_sum_of_powers_slow() {
254+
fn test_sum_of_powers_non_pow2() {
203255
let x = Scalar::from_u64(10);
204-
assert_eq!(sum_of_powers_slow(&x, 0), Scalar::zero());
205-
assert_eq!(sum_of_powers_slow(&x, 1), Scalar::one());
206-
assert_eq!(sum_of_powers_slow(&x, 2), Scalar::from_u64(11));
207-
assert_eq!(sum_of_powers_slow(&x, 3), Scalar::from_u64(111));
208-
assert_eq!(sum_of_powers_slow(&x, 4), Scalar::from_u64(1111));
209-
assert_eq!(sum_of_powers_slow(&x, 5), Scalar::from_u64(11111));
210-
assert_eq!(sum_of_powers_slow(&x, 6), Scalar::from_u64(111111));
256+
assert_eq!(sum_of_powers(&x, 0), Scalar::zero());
257+
assert_eq!(sum_of_powers(&x, 1), Scalar::one());
258+
assert_eq!(sum_of_powers(&x, 2), Scalar::from_u64(11));
259+
assert_eq!(sum_of_powers(&x, 3), Scalar::from_u64(111));
260+
assert_eq!(sum_of_powers(&x, 4), Scalar::from_u64(1111));
261+
assert_eq!(sum_of_powers(&x, 5), Scalar::from_u64(11111));
262+
assert_eq!(sum_of_powers(&x, 6), Scalar::from_u64(111111));
263+
assert_eq!(sum_of_powers(&x, 7), Scalar::from_u64(1111111));
264+
assert_eq!(sum_of_powers(&x, 8), Scalar::from_u64(11111111));
211265
}
212266
}

0 commit comments

Comments
 (0)