Skip to content

Commit

Permalink
Twist, d=1 (#573)
Browse files Browse the repository at this point in the history
* Val-evaluation sumcheck

* Read-checking sumcheck

* Combined read/write-checking sumcheck

* Twist e2e

* Benchmark, tracing spans, and optimizations

* use Zipf distribution in benchmark

* Optimize ra/wa materialization and memory allocations

* Switch binding order for second half of Twist read/write-checking sumcheck

* Check for 0s in second half of sumcheck

* Preemptively multiply eq(r, x) by z

* Avoid unnecessary memcpy when materializing val
  • Loading branch information
moodlezoup authored Feb 20, 2025
1 parent c93148d commit dae559d
Show file tree
Hide file tree
Showing 6 changed files with 1,404 additions and 20 deletions.
13 changes: 12 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions jolt-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ tokio = { version = "1.38.0", optional = true }
alloy-primitives = "0.7.6"
alloy-sol-types = "0.7.6"
once_cell = "1.19.0"
rand_distr = "0.4.3"

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
73 changes: 73 additions & 0 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::commitment::hyperkzg::HyperKZG;
use crate::poly::commitment::zeromorph::Zeromorph;
use crate::subprotocols::shout::ShoutProof;
use crate::subprotocols::twist::{TwistAlgorithm, TwistProof};
use crate::utils::math::Math;
use crate::utils::transcript::{KeccakTranscript, Transcript};
use ark_bn254::{Bn254, Fr};
use ark_std::test_rng;
use rand_core::RngCore;
use rand_distr::{Distribution, Zipf};
use serde::Serialize;

#[derive(Debug, Copy, Clone, clap::ValueEnum)]
Expand All @@ -26,6 +28,7 @@ pub enum BenchType {
Sha3,
Sha2Chain,
Shout,
Twist,
}

#[allow(unreachable_patterns)] // good errors on new BenchTypes
Expand All @@ -47,6 +50,7 @@ pub fn benchmarks(
fibonacci::<Fr, Zeromorph<Bn254, KeccakTranscript>, KeccakTranscript>()
}
BenchType::Shout => shout::<Fr, KeccakTranscript>(),
BenchType::Twist => twist::<Fr, KeccakTranscript>(),
_ => panic!("BenchType does not have a mapping"),
},
PCSType::HyperKZG => match bench_type {
Expand All @@ -59,6 +63,7 @@ pub fn benchmarks(
fibonacci::<Fr, HyperKZG<Bn254, KeccakTranscript>, KeccakTranscript>()
}
BenchType::Shout => shout::<Fr, KeccakTranscript>(),
BenchType::Twist => twist::<Fr, KeccakTranscript>(),
_ => panic!("BenchType does not have a mapping"),
},
_ => panic!("PCS Type does not have a mapping"),
Expand Down Expand Up @@ -105,6 +110,74 @@ where
tasks
}

fn twist<F, ProofTranscript>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
ProofTranscript: Transcript,
{
let small_value_lookup_tables = F::compute_lookup_tables();
F::initialize_lookup_tables(small_value_lookup_tables);

let mut tasks = Vec::new();

const K: usize = 1 << 10;
const T: usize = 1 << 20;
const ZIPF_S: f64 = 0.0;
let zipf = Zipf::new(K as u64, ZIPF_S).unwrap();

let mut rng = test_rng();

let mut registers = [0u32; K];
let mut read_addresses: Vec<usize> = Vec::with_capacity(T);
let mut read_values: Vec<u32> = Vec::with_capacity(T);
let mut write_addresses: Vec<usize> = Vec::with_capacity(T);
let mut write_values: Vec<u32> = Vec::with_capacity(T);
let mut write_increments: Vec<i64> = Vec::with_capacity(T);
for _ in 0..T {
// Random read register
let read_address = zipf.sample(&mut rng) as usize - 1;
// Random write register
let write_address = zipf.sample(&mut rng) as usize - 1;
read_addresses.push(read_address);
write_addresses.push(write_address);
// Read the value currently in the read register
read_values.push(registers[read_address]);
// Random write value
let write_value = rng.next_u32();
write_values.push(write_value);
// The increment is the difference between the new value and the old value
let write_increment = (write_value as i64) - (registers[write_address] as i64);
write_increments.push(write_increment);
// Write the new value to the write register
registers[write_address] = write_value;
}

let mut prover_transcript = ProofTranscript::new(b"test_transcript");
let r: Vec<F> = prover_transcript.challenge_vector(K.log_2());
let r_prime: Vec<F> = prover_transcript.challenge_vector(T.log_2());

let task = move || {
let _proof = TwistProof::prove(
read_addresses,
read_values,
write_addresses,
write_values,
write_increments,
r.clone(),
r_prime.clone(),
&mut prover_transcript,
TwistAlgorithm::Local,
);
};

tasks.push((
tracing::info_span!("Twist d=1"),
Box::new(task) as Box<dyn FnOnce()>,
));

tasks
}

fn fibonacci<F, PCS, ProofTranscript>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
Expand Down
37 changes: 18 additions & 19 deletions jolt-core/src/poly/eq_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ impl<F: JoltField> EqPolynomial<F> {

#[tracing::instrument(skip_all, name = "EqPolynomial::evals")]
pub fn evals(r: &[F]) -> Vec<F> {
let ell = r.len();

match ell {
0..=PARALLEL_THRESHOLD => Self::evals_serial(r, ell, None),
_ => Self::evals_parallel(r, ell, None),
match r.len() {
0..=PARALLEL_THRESHOLD => Self::evals_serial(r, None),
_ => Self::evals_parallel(r, None),
}
}

Expand All @@ -45,19 +43,19 @@ impl<F: JoltField> EqPolynomial<F> {
/// the dynamic programming tree to R^2 instead of 1.
#[tracing::instrument(skip_all, name = "EqPolynomial::evals_with_r2")]
pub fn evals_with_r2(r: &[F]) -> Vec<F> {
let ell = r.len();

match ell {
0..=PARALLEL_THRESHOLD => Self::evals_serial(r, ell, F::montgomery_r2()),
_ => Self::evals_parallel(r, ell, F::montgomery_r2()),
match r.len() {
0..=PARALLEL_THRESHOLD => Self::evals_serial(r, F::montgomery_r2()),
_ => Self::evals_parallel(r, F::montgomery_r2()),
}
}

/// Computes evals serially. Uses less memory (and fewer allocations) than `evals_parallel`.
fn evals_serial(r: &[F], ell: usize, r2: Option<F>) -> Vec<F> {
let mut evals: Vec<F> = vec![r2.unwrap_or(F::one()); ell.pow2()];
/// Computes the table of coefficients:
/// scaling_factor * eq(r, x) for all x in {0, 1}^n
/// serially. More efficient for short `r`.
fn evals_serial(r: &[F], scaling_factor: Option<F>) -> Vec<F> {
let mut evals: Vec<F> = vec![scaling_factor.unwrap_or(F::one()); r.len().pow2()];
let mut size = 1;
for j in 0..ell {
for j in 0..r.len() {
// in each iteration, we double the size of chis
size *= 2;
for i in (0..size).rev().step_by(2) {
Expand All @@ -70,14 +68,15 @@ impl<F: JoltField> EqPolynomial<F> {
evals
}

/// Computes evals in parallel. Uses more memory and allocations than `evals_serial`, but
/// evaluates biggest layers of the dynamic programming tree in parallel.
/// Computes the table of coefficients:
/// scaling_factor * eq(r, x) for all x in {0, 1}^n
/// computing biggest layers of the dynamic programming tree in parallel.
#[tracing::instrument(skip_all, "EqPolynomial::evals_parallel")]
pub fn evals_parallel(r: &[F], ell: usize, r2: Option<F>) -> Vec<F> {
let final_size = (2usize).pow(ell as u32);
pub fn evals_parallel(r: &[F], scaling_factor: Option<F>) -> Vec<F> {
let final_size = r.len().pow2();
let mut evals: Vec<F> = unsafe_allocate_zero_vec(final_size);
let mut size = 1;
evals[0] = r2.unwrap_or(F::one());
evals[0] = scaling_factor.unwrap_or(F::one());

for r in r.iter().rev() {
let (evals_left, evals_right) = evals.split_at_mut(size);
Expand Down
1 change: 1 addition & 0 deletions jolt-core/src/subprotocols/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod grand_product_quarks;
pub mod shout;
pub mod sparse_grand_product;
pub mod sumcheck;
pub mod twist;

#[derive(Clone, Copy, Debug, Default)]
pub enum QuarkHybridLayerDepth {
Expand Down
Loading

0 comments on commit dae559d

Please sign in to comment.