Skip to content

Commit f8288e3

Browse files
ArtiomTrsauliusgrigaitis
authored andcommitted
Added ability to parametrize window size
1 parent 4e7b11a commit f8288e3

File tree

6 files changed

+212
-115
lines changed

6 files changed

+212
-115
lines changed

kzg/src/msm/arkmsm/arkmsm_msm.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@ impl VariableBaseMSM {
1515
/// on a Ubuntu 20.04.2 LTS server with AMD EPYC 7282 16-Core CPU
1616
/// and 128G memory, the optimal performance may vary on a different
1717
/// configuration.
18-
const fn get_opt_window_size(k: u32) -> u32 {
19-
match k {
20-
0..=9 => 8,
21-
10..=12 => 10,
22-
13..=14 => 12,
23-
15..=19 => 13,
24-
20..=22 => 15,
25-
23.. => 16,
26-
}
18+
fn get_opt_window_size(k: u32) -> u32 {
19+
option_env!("WINDOW_SIZE")
20+
.and_then(|v| v.parse().ok())
21+
.unwrap_or({
22+
match k {
23+
0..=9 => 8,
24+
10..=12 => 10,
25+
13..=14 => 12,
26+
15..=19 => 13,
27+
20..=22 => 15,
28+
23.. => 16,
29+
}
30+
})
2731
}
2832

2933
pub fn msm_slice(mut scalar: Scalar256, slices: &mut [u32], window_bits: u32) {

kzg/src/msm/bgmw.rs

Lines changed: 84 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -99,73 +99,100 @@ const fn get_sequential_window_size(window: BgmwWindow) -> usize {
9999
/// 2^w - 2 - computing total bucket sum (bucket aggregation). Total number of buckets (scratch size) is 2^(w-1).
100100
/// Adding each point to total bucket sum requires 2 point addition operations, so 2 * 2^(w-1) = 2^w.
101101
#[allow(unused)]
102-
const fn bgmw_window_size(npoints: usize) -> usize {
103-
let wbits = num_bits(npoints);
104-
105-
match (wbits) {
106-
1 => 4,
107-
2..=3 => 5,
108-
4 => 6,
109-
5 => 7,
110-
6..=7 => 8,
111-
8 => 9,
112-
9..=10 => 10,
113-
11 => 11,
114-
12 => 12,
115-
13..=14 => 13,
116-
15..=16 => 15,
117-
17 => 16,
118-
18..=19 => 17,
119-
20 => 19,
120-
21..=22 => 20,
121-
23..=24 => 22,
122-
25..=26 => 24,
123-
27..=29 => 26,
124-
30..=32 => 29,
125-
33..=37 => 32,
126-
_ => 37,
127-
}
102+
fn bgmw_window_size(npoints: usize) -> usize {
103+
option_env!("WINDOW_SIZE")
104+
.map(|v| {
105+
v.parse()
106+
.expect("WINDOW_SIZE environment variable must be valid number")
107+
})
108+
.unwrap_or({
109+
let wbits = num_bits(npoints);
110+
111+
match (wbits) {
112+
1 => 4,
113+
2..=3 => 5,
114+
4 => 6,
115+
5 => 7,
116+
6..=7 => 8,
117+
8 => 9,
118+
9..=10 => 10,
119+
11 => 11,
120+
12 => 12,
121+
13..=14 => 13,
122+
15..=16 => 15,
123+
17 => 16,
124+
18..=19 => 17,
125+
20 => 19,
126+
21..=22 => 20,
127+
23..=24 => 22,
128+
25..=26 => 24,
129+
27..=29 => 26,
130+
30..=32 => 29,
131+
33..=37 => 32,
132+
_ => 37,
133+
}
134+
})
128135
}
129136

130137
#[cfg(feature = "parallel")]
131-
const fn bgmw_parallel_window_size(npoints: usize, ncpus: usize) -> (usize, usize, usize) {
132-
let mut min_ops = usize::MAX;
133-
let mut opt = 0;
134-
135-
let mut win = 2;
136-
while win <= 40 {
137-
let ops = (1 << win) + (255usize.div_ceil(win).div_ceil(ncpus) * npoints) - 2;
138-
if min_ops >= ops {
139-
min_ops = ops;
140-
opt = win;
141-
}
142-
win += 1;
143-
}
138+
#[allow(clippy::option_env_unwrap)]
139+
fn bgmw_parallel_window_size(npoints: usize, ncpus: usize) -> (usize, usize, usize) {
140+
option_env!("WINDOW_NX")
141+
.and_then(|v| v.parse().ok())
142+
.map(|nx| {
143+
let wnd = option_env!("WINDOW_SIZE")
144+
.expect(
145+
"Unable to use BGMW: when specifying WINDOW_NX environment \
146+
variable, please also specify WINDOW_SIZE",
147+
)
148+
.parse()
149+
.expect("WINDOW_SIZE environment variable must be valid number");
150+
151+
(
152+
nx,
153+
255usize.div_ceil(wnd) + is_zero((NBITS % wnd) as u64) as usize,
154+
wnd,
155+
)
156+
})
157+
.unwrap_or({
158+
let mut min_ops = usize::MAX;
159+
let mut opt = 0;
160+
161+
let mut win = 2;
162+
while win <= 40 {
163+
let ops = (1 << win) + (255usize.div_ceil(win).div_ceil(ncpus) * npoints) - 2;
164+
if min_ops >= ops {
165+
min_ops = ops;
166+
opt = win;
167+
}
168+
win += 1;
169+
}
144170

145-
let mut mult = 1;
171+
let mut mult = 1;
146172

147-
let mut opt_x = 1;
173+
let mut opt_x = 1;
148174

149-
while mult <= 8 {
150-
let nx = ncpus * mult;
151-
let wnd = bgmw_window_size(npoints / nx);
175+
while mult <= 8 {
176+
let nx = ncpus * mult;
177+
let wnd = bgmw_window_size(npoints / nx);
152178

153-
let ops = mult * 255usize.div_ceil(wnd) * npoints.div_ceil(nx) + (1 << wnd) - 2;
179+
let ops = mult * 255usize.div_ceil(wnd) * npoints.div_ceil(nx) + (1 << wnd) - 2;
154180

155-
if min_ops > ops {
156-
min_ops = ops;
157-
opt = wnd;
158-
opt_x = nx;
159-
}
181+
if min_ops > ops {
182+
min_ops = ops;
183+
opt = wnd;
184+
opt_x = nx;
185+
}
160186

161-
mult += 1;
162-
}
187+
mult += 1;
188+
}
163189

164-
(
165-
opt_x,
166-
255usize.div_ceil(opt) + is_zero((NBITS % opt) as u64) as usize,
167-
opt,
168-
)
190+
(
191+
opt_x,
192+
255usize.div_ceil(opt) + is_zero((NBITS % opt) as u64) as usize,
193+
opt,
194+
)
195+
})
169196
}
170197

171198
impl<
Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,47 @@
11
use crate::msm::pippenger_utils::num_bits;
22

3-
pub const fn breakdown(window: usize, ncpus: usize) -> (usize, usize, usize) {
3+
pub fn breakdown(window: usize, ncpus: usize) -> (usize, usize, usize) {
44
const NBITS: usize = 255;
5-
let mut nx: usize;
6-
let mut wnd: usize;
75

8-
if NBITS > window * ncpus {
9-
nx = 1;
10-
wnd = num_bits(ncpus / 4);
11-
if (window + wnd) > 18 {
12-
wnd = window - wnd;
13-
} else {
14-
wnd = (NBITS / window).div_ceil(ncpus);
15-
if (NBITS / (window + 1)).div_ceil(ncpus) < wnd {
16-
wnd = window + 1;
6+
option_env!("WINDOW_NX")
7+
.map(|v| {
8+
v.parse()
9+
.expect("WINDOW_NX environment variable must be valid number")
10+
})
11+
.map(|nx| {
12+
let ny = NBITS / window + 1;
13+
(nx, ny, NBITS / ny + 1)
14+
})
15+
.unwrap_or({
16+
let mut nx: usize;
17+
let mut wnd: usize;
18+
19+
if NBITS > window * ncpus {
20+
nx = 1;
21+
wnd = num_bits(ncpus / 4);
22+
if (window + wnd) > 18 {
23+
wnd = window - wnd;
24+
} else {
25+
wnd = (NBITS / window).div_ceil(ncpus);
26+
if (NBITS / (window + 1)).div_ceil(ncpus) < wnd {
27+
wnd = window + 1;
28+
} else {
29+
wnd = window;
30+
}
31+
}
1732
} else {
18-
wnd = window;
33+
nx = 2;
34+
wnd = window - 2;
35+
while (NBITS / wnd + 1) * nx < ncpus {
36+
nx += 1;
37+
wnd = window - num_bits(3 * nx / 2);
38+
}
39+
nx -= 1;
40+
wnd = window - num_bits(3 * nx / 2);
1941
}
20-
}
21-
} else {
22-
nx = 2;
23-
wnd = window - 2;
24-
while (NBITS / wnd + 1) * nx < ncpus {
25-
nx += 1;
26-
wnd = window - num_bits(3 * nx / 2);
27-
}
28-
nx -= 1;
29-
wnd = window - num_bits(3 * nx / 2);
30-
}
31-
let ny = NBITS / wnd + 1;
32-
wnd = NBITS / ny + 1;
42+
let ny = NBITS / wnd + 1;
43+
wnd = NBITS / ny + 1;
3344

34-
(nx, ny, wnd)
45+
(nx, ny, wnd)
46+
})
3547
}

kzg/src/msm/pippenger_utils.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -297,16 +297,23 @@ pub const fn num_bits(l: usize) -> usize {
297297
/// Adding each point to total bucket sum requires 2 point addition operations, so 2 * 2^(w-1) = 2^w.
298298
/// w + 1 - each bucket sum must be multiplied by 2^w. To do this, we need w doublings. Adding this sum to the
299299
/// total requires one more point addition, hence +1.
300-
pub const fn pippenger_window_size(npoints: usize) -> usize {
301-
let wbits = num_bits(npoints);
302-
303-
if wbits > 13 {
304-
return wbits - 4;
305-
}
306-
if wbits > 5 {
307-
return wbits - 3;
308-
}
309-
2
300+
pub fn pippenger_window_size(npoints: usize) -> usize {
301+
option_env!("WINDOW_SIZE")
302+
.map(|v| {
303+
v.parse()
304+
.expect("WINDOW_SIZE environment variable must be valid number")
305+
})
306+
.unwrap_or({
307+
let wbits = num_bits(npoints);
308+
309+
if wbits > 13 {
310+
return wbits - 4;
311+
}
312+
if wbits > 5 {
313+
return wbits - 3;
314+
}
315+
2
316+
})
310317
}
311318

312319
#[cfg(test)]

kzg/src/msm/wbits.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ use core::{marker::PhantomData, ops::Neg};
33

44
use crate::{Fr, G1Affine, G1Fp, G1GetFp, G1Mul, G1ProjAddAffine, G1};
55

6-
const WBITS: usize = 8;
7-
86
#[derive(Debug, Clone)]
97
pub struct WbitsTable<TFr, TG1, TG1Fp, TG1Affine, TG1ProjAddAffine>
108
where
@@ -26,6 +24,15 @@ where
2624
g1_affine_add_marker: PhantomData<TG1ProjAddAffine>,
2725
}
2826

27+
fn get_window_size() -> usize {
28+
option_env!("WINDOW_SIZE")
29+
.map(|v| {
30+
v.parse()
31+
.expect("WINDOW_SIZE environment variable must be valid number")
32+
})
33+
.unwrap_or(8)
34+
}
35+
2936
// Code was taken from: https://github.com/privacy-scaling-explorations/halo2curves/blob/b753a832e92d5c86c5c997327a9cf9de86a18851/src/msm.rs#L13
3037
pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
3138
// Booth encoding:
@@ -294,13 +301,13 @@ impl<
294301
let mut table = Vec::new();
295302

296303
table
297-
.try_reserve_exact(points.len() * (1 << (WBITS - 1)))
304+
.try_reserve_exact(points.len() * (1 << (get_window_size() - 1)))
298305
.map_err(|_| "WBITS precomputation table is too large".to_string())?;
299306

300307
for point in points {
301308
let mut current = point.clone();
302309

303-
for _ in 0..(1 << (WBITS - 1)) {
310+
for _ in 0..(1 << (get_window_size() - 1)) {
304311
table.push(TG1Affine::into_affine(&current));
305312
current = current.add_or_dbl(point);
306313
}
@@ -329,13 +336,13 @@ impl<
329336
for row in matrix {
330337
let mut temp_table = Vec::new();
331338
temp_table
332-
.try_reserve_exact(row.len() * (1 << (WBITS - 1)))
339+
.try_reserve_exact(row.len() * (1 << (get_window_size() - 1)))
333340
.map_err(|_| "WBITS precomputation table is too large".to_owned())?;
334341

335342
for point in row {
336343
let mut current = point.clone();
337344

338-
for _ in 0..(1 << (WBITS - 1)) {
345+
for _ in 0..(1 << (get_window_size() - 1)) {
339346
temp_table.push(TG1Affine::into_affine(&current));
340347
current = current.add_or_dbl(point);
341348
}
@@ -362,15 +369,16 @@ impl<
362369
fn multiply_sequential_raw(bases: &[TG1Affine], scalars: &[TFr]) -> TG1 {
363370
let scalars = scalars.iter().map(TFr::to_scalar).collect::<Vec<_>>();
364371

365-
let number_of_windows = 255 / WBITS + 1;
372+
let number_of_windows = 255 / get_window_size() + 1;
366373
let mut windows_of_points = vec![Vec::with_capacity(scalars.len()); number_of_windows];
367374

368375
for window_idx in 0..windows_of_points.len() {
369376
for (scalar_idx, scalar_bytes) in scalars.iter().enumerate() {
370-
let sub_table =
371-
&bases[scalar_idx * (1 << (WBITS - 1))..(scalar_idx + 1) * (1 << (WBITS - 1))];
377+
let sub_table = &bases[scalar_idx * (1 << (get_window_size() - 1))
378+
..(scalar_idx + 1) * (1 << (get_window_size() - 1))];
372379

373-
let point_idx = get_booth_index(window_idx, WBITS, scalar_bytes.as_u8());
380+
let point_idx =
381+
get_booth_index(window_idx, get_window_size(), scalar_bytes.as_u8());
374382

375383
if point_idx == 0 {
376384
continue;
@@ -396,7 +404,7 @@ impl<
396404
let mut result: TG1 = accumulated_points.last().unwrap().clone();
397405
for point in accumulated_points.into_iter().rev().skip(1) {
398406
// Double the result 'wbits' times
399-
for _ in 0..WBITS {
407+
for _ in 0..get_window_size() {
400408
result = result.dbl();
401409
}
402410
// Add the accumulated point for this window

0 commit comments

Comments
 (0)