From 8bd1fac8a8e1cd003e0c6d2fe99f588dadc03a63 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 8 Aug 2024 17:15:29 +0800 Subject: [PATCH 1/2] monomial expression to virtual poly --- ceno_zkvm/src/expression.rs | 103 ++++++++++++++++++++++++++++- ceno_zkvm/src/scheme/prover.rs | 1 + ceno_zkvm/src/virtual_polys.rs | 116 ++++++++++++++++++++++++++++++++- 3 files changed, 217 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 038b82d4c..232b92575 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -18,11 +18,18 @@ pub enum Expression { Sum(Box>, Box>), /// This is the product of two polynomials Product(Box>, Box>), - /// This is a ax + b polynomial + /// This is x, a, b expr to represent ax + b polynomial ScaledSum(Box>, Box>, Box>), Challenge(ChallengeId, usize, E, E), // (challenge_id, power, scalar, offset) } +/// this is used as finite state machine state +/// for differentiate a expression is in monomial form or not +enum MonomialState { + SumTerm, + ProductTerm, +} + impl Expression { pub fn degree(&self) -> usize { match self { @@ -69,6 +76,46 @@ impl Expression { } } } + + pub fn is_monomial_form(&self) -> bool { + Self::is_monomial_form_inner(MonomialState::SumTerm, self) + } + + fn is_zero_expr(expr: &Expression) -> bool { + match expr { + Expression::WitIn(_) => false, + Expression::Constant(c) => *c == E::BaseField::ZERO, + Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b), + Expression::Product(a, b) => Self::is_zero_expr(a) || Self::is_zero_expr(b), + Expression::ScaledSum(_, _, _) => false, + Expression::Challenge(_, _, _, _) => false, + } + } + fn is_monomial_form_inner(s: MonomialState, expr: &Expression) -> bool { + match (expr, s) { + (Expression::WitIn(_), MonomialState::SumTerm) => true, + (Expression::WitIn(_), MonomialState::ProductTerm) => true, + (Expression::Constant(_), MonomialState::SumTerm) => true, + (Expression::Constant(_), MonomialState::ProductTerm) => true, + (Expression::Sum(a, b), MonomialState::SumTerm) => { + Self::is_monomial_form_inner(MonomialState::SumTerm, a) + && Self::is_monomial_form_inner(MonomialState::SumTerm, b) + } + (Expression::Sum(_, _), MonomialState::ProductTerm) => false, + (Expression::Product(a, b), MonomialState::SumTerm) => { + Self::is_monomial_form_inner(MonomialState::ProductTerm, a) + && Self::is_monomial_form_inner(MonomialState::ProductTerm, b) + } + (Expression::Product(a, b), MonomialState::ProductTerm) => { + Self::is_monomial_form_inner(MonomialState::ProductTerm, a) + && Self::is_monomial_form_inner(MonomialState::ProductTerm, b) + } + (Expression::ScaledSum(_, _, _), MonomialState::SumTerm) => true, + (Expression::ScaledSum(_, _, b), MonomialState::ProductTerm) => Self::is_zero_expr(b), + (Expression::Challenge(_, _, _, _), MonomialState::SumTerm) => true, + (Expression::Challenge(_, _, _, _), MonomialState::ProductTerm) => true, + } + } } impl Neg for Expression { @@ -425,4 +472,58 @@ mod tests { ) ); } + + #[test] + fn test_is_monomial_form() { + type E = GoldilocksExt2; + let mut cb = CircuitBuilder::::new(); + let x = cb.create_witin(); + let y = cb.create_witin(); + let z = cb.create_witin(); + // scaledsum * challenge + // 3 * x + 2 + let expr: Expression = + Into::>::into(3usize) * x.expr() + Into::>::into(2usize); + assert_eq!(expr.is_monomial_form(), true); + + // 2 product term + let expr: Expression = Into::>::into(3usize) * x.expr() * y.expr() + + Into::>::into(2usize) * x.expr(); + assert_eq!(expr.is_monomial_form(), true); + + // complex linear operation + // (2c + 3) * x * y - 6z + let expr: Expression = + Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr() + - Into::>::into(6usize) * z.expr(); + assert_eq!(expr.is_monomial_form(), true); + + // complex linear operation + // (2c + 3) * x * y - 6z + let expr: Expression = + Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr() + - Into::>::into(6usize) * z.expr(); + assert_eq!(expr.is_monomial_form(), true); + + // complex linear operation + // (2 * x + 3) * 3 + 6 * 8 + let expr: Expression = (Into::>::into(2usize) * x.expr() + + Into::>::into(3usize)) + * Into::>::into(3usize) + + Into::>::into(6usize) * Into::>::into(8usize); + assert_eq!(expr.is_monomial_form(), true); + } + + #[test] + fn test_not_monomial_form() { + type E = GoldilocksExt2; + let mut cb = CircuitBuilder::::new(); + let x = cb.create_witin(); + let y = cb.create_witin(); + // scaledsum * challenge + // (x + 1) * (y + 1) + let expr: Expression = (Into::>::into(1usize) + x.expr()) + * (Into::>::into(2usize) + y.expr()); + assert_eq!(expr.is_monomial_form(), false); + } } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 1f4c9a823..9d5525440 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -330,6 +330,7 @@ impl ZKVMProver { ); } } + let (main_sel_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( num_threads, virtual_polys.get_batched_polys(), diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index ee630bfdd..3b3da0b3a 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -1,11 +1,11 @@ -use std::sync::Arc; +use std::{collections::BTreeSet, mem, sync::Arc}; use ff_ext::ExtensionField; use gkr::util::ceil_log2; use itertools::Itertools; use multilinear_extensions::virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}; -use crate::structs::VirtualPolynomials; +use crate::{expression::Expression, structs::VirtualPolynomials}; impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { pub fn new(num_threads: usize, num_variables: usize) -> Self { @@ -57,4 +57,116 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { pub fn get_batched_polys(self) -> Vec> { self.polys } + + pub fn add_mle_list_by_expr( + &mut self, + // witin_id -> thread_id + wit_ins: Vec>>, + expr: &Expression, + challenges: &[E], + ) { + assert!(expr.is_monomial_form()); + let monomial_terms = expr.evaluate( + &|witness_id| { + vec![(E::ONE, { + let mut monomial_terms = BTreeSet::new(); + monomial_terms.insert(witness_id); + monomial_terms + })] + }, + &|scalar| vec![(E::from(scalar), { BTreeSet::new() })], + &|challenge_id, pow, scalar, offset| { + let challenge = challenges[challenge_id as usize]; + vec![( + challenge.pow(&[pow as u64]) * scalar + offset, + BTreeSet::new(), + )] + }, + &|mut a, b| { + a.extend(b); + a + }, + &|mut a, mut b| { + assert!(a.len() <= 2); + assert!(b.len() <= 2); + // special logic to deal with scaledsum + // scaledsum second parameter must be 0 + if a.len() == 2 { + assert!((a[1].0, a[1].1.is_empty()) == (E::ZERO, true)); + a.truncate(1); + } + if b.len() == 2 { + assert!((b[1].0, b[1].1.is_empty()) == (E::ZERO, true)); + b.truncate(1); + } + + a[0].1.extend(mem::take(&mut b[0].1)); + // return [ab] + vec![(a[0].0 * b[0].0, mem::take(&mut a[0].1))] + }, + &|mut x, a, b| { + assert!(a.len() == 1 && a[0].1.is_empty()); // for challenge or constant, term should be empty + assert!(b.len() == 1 && b[0].1.is_empty()); // for challenge or constant, term should be empty + assert!(x.len() == 1 && (x[0].0, x[0].1.len()) == (E::ONE, 1)); // witin size only 1 + if b[0].0 == E::ZERO { + // only include first term if b = 0 + vec![(a[0].0, mem::take(&mut x[0].1))] + } else { + // return [ax, b] + vec![(a[0].0, mem::take(&mut x[0].1)), (b[0].0, BTreeSet::new())] + } + }, + ); + for (constant, monomial_term) in monomial_terms.iter() { + if *constant != E::ZERO && monomial_term.is_empty() { + todo!("make virtual poly support pure constant") + } + for thread_id in 0..self.num_threads { + let terms_polys = monomial_term + .iter() + .map(|wit_id| wit_ins[*wit_id as usize][thread_id].clone()) + .collect_vec(); + + self.add_mle_list(thread_id, terms_polys, *constant); + } + } + } +} + +#[cfg(test)] +mod tests { + + use goldilocks::{Goldilocks, GoldilocksExt2}; + use multilinear_extensions::{mle::IntoMLE, virtual_poly_v2::ArcMultilinearExtension}; + + use crate::{ + circuit_builder::CircuitBuilder, + expression::{Expression, ToExpr}, + structs::VirtualPolynomials, + }; + + #[test] + fn test_add_mle_list_by_expr() { + type E = GoldilocksExt2; + let mut cb = CircuitBuilder::::new(); + let x = cb.create_witin(); + let y = cb.create_witin(); + + let wits_in: Vec> = (0..cb.num_witin as usize) + .map(|_| vec![Goldilocks::from(1)].into_mle().into()) + .collect(); + + let mut virtual_polys = VirtualPolynomials::new(1, 0); + let wits_threads: Vec>> = wits_in + .iter() + .map(|wit_poly| virtual_polys.get_all_range_polys(wit_poly)) + .collect(); + + // 3xy + 2y + let expr: Expression = + Expression::from(3) * x.expr() * y.expr() + Expression::from(2) * y.expr(); + + println!("expr {:?}", expr); + virtual_polys.add_mle_list_by_expr(wits_threads, &expr, &[]); + } } From 57a2d323520cf713c7a545c70e4d3b0e60019639 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 8 Aug 2024 23:38:05 +0800 Subject: [PATCH 2/2] degree > 1 sumcheck batched with main constraint --- ceno_zkvm/benches/riscv_add.rs | 20 ++--- ceno_zkvm/src/chip_handler.rs | 2 +- ceno_zkvm/src/chip_handler/general.rs | 9 ++- ceno_zkvm/src/circuit_builder.rs | 5 ++ ceno_zkvm/src/instructions/riscv/add.rs | 20 ++--- ceno_zkvm/src/scheme/constants.rs | 2 + ceno_zkvm/src/scheme/prover.rs | 100 ++++++++++++++++++++---- ceno_zkvm/src/scheme/utils.rs | 15 +--- ceno_zkvm/src/scheme/verifier.rs | 63 ++++++++++++--- ceno_zkvm/src/virtual_polys.rs | 34 ++++++-- 10 files changed, 203 insertions(+), 67 deletions(-) diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 9a22fed06..bbd7556f5 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -1,10 +1,7 @@ #![allow(clippy::manual_memcpy)] #![allow(clippy::needless_range_loop)] -use std::{ - collections::BTreeMap, - time::{Duration, Instant}, -}; +use std::time::{Duration, Instant}; use ark_std::test_rng; use ceno_zkvm::{ @@ -17,8 +14,8 @@ use criterion::*; use ff_ext::ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; +use itertools::Itertools; use multilinear_extensions::mle::IntoMLE; -use simple_frontend::structs::WitnessId; use transcript::Transcript; cfg_if::cfg_if! { @@ -93,17 +90,16 @@ fn bench_add(c: &mut Criterion) { }, |(mut rng, real_challenges)| { // generate mock witness - let mut wits_in = BTreeMap::new(); let num_instances = 1 << instance_num_vars; - (0..num_witin as usize).for_each(|witness_id| { - wits_in.insert( - witness_id as WitnessId, + let wits_in = (0..num_witin as usize) + .map(|_| { (0..num_instances) .map(|_| Goldilocks::random(&mut rng)) .collect::>() - .into_mle(), - ); - }); + .into_mle() + .into() + }) + .collect_vec(); let timer = Instant::now(); let _ = prover .create_proof( diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index c53f326aa..5ddab3662 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -2,7 +2,7 @@ use ff_ext::ExtensionField; use crate::{ error::ZKVMError, - expression::{Expression, WitIn}, + expression::WitIn, structs::{PCUInt, TSUInt, UInt64}, }; diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 7e49382ff..95ad15fce 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -18,6 +18,7 @@ impl CircuitBuilder { lk_expressions: vec![], assert_zero_expressions: vec![], assert_zero_sumcheck_expressions: vec![], + max_non_lc_degree: 0, chip_record_alpha: Expression::Challenge(0, 1, E::ONE, E::ZERO), chip_record_beta: Expression::Challenge(1, 1, E::ONE, E::ZERO), phantom: std::marker::PhantomData, @@ -96,7 +97,12 @@ impl CircuitBuilder { if assert_zero_expr.degree() == 1 { self.assert_zero_expressions.push(assert_zero_expr); } else { - // TODO check expression must be in multivariate monomial form + assert_eq!( + assert_zero_expr.is_monomial_form(), + true, + "only support sumcheck in monomial form" + ); + self.max_non_lc_degree = self.max_non_lc_degree.max(assert_zero_expr.degree()); self.assert_zero_sumcheck_expressions.push(assert_zero_expr); } Ok(()) @@ -132,6 +138,7 @@ impl CircuitBuilder { lk_expressions: self.lk_expressions.clone(), assert_zero_expressions: self.assert_zero_expressions.clone(), assert_zero_sumcheck_expressions: self.assert_zero_sumcheck_expressions.clone(), + max_non_lc_degree: self.max_non_lc_degree, } } } diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 287460a66..89aad92e2 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -19,6 +19,8 @@ pub struct CircuitBuilder { pub assert_zero_expressions: Vec>, /// main constraints zero expression for expression degree > 1, which require sumcheck to prove pub assert_zero_sumcheck_expressions: Vec>, + /// max zero sumcheck degree + pub max_non_lc_degree: usize, // alpha, beta challenge for chip record pub chip_record_alpha: Expression, @@ -39,4 +41,7 @@ pub struct Circuit { pub assert_zero_expressions: Vec>, /// main constraints zero expression for expression degree > 1, which require sumcheck to prove pub assert_zero_sumcheck_expressions: Vec>, + + /// max zero sumcheck degree + pub max_non_lc_degree: usize, } diff --git a/ceno_zkvm/src/instructions/riscv/add.rs b/ceno_zkvm/src/instructions/riscv/add.rs index 7aa11ff2f..4262b1138 100644 --- a/ceno_zkvm/src/instructions/riscv/add.rs +++ b/ceno_zkvm/src/instructions/riscv/add.rs @@ -65,6 +65,10 @@ impl Instruction for AddInstruction { circuit_builder.assert_u5(rs2_id.expr())?; circuit_builder.assert_u5(rd_id.expr())?; + // TODO remove me, this is just for testing degree > 1 sumcheck in main constraints + circuit_builder + .require_zero(rs1_id.expr() * rs1_id.expr() - rs1_id.expr() * rs1_id.expr())?; + let mut prev_rs1_ts = TSUInt::new(circuit_builder); let mut prev_rs2_ts = TSUInt::new(circuit_builder); let mut prev_rd_ts = TSUInt::new(circuit_builder); @@ -106,15 +110,14 @@ impl Instruction for AddInstruction { #[cfg(test)] mod test { - use std::collections::BTreeMap; use ark_std::test_rng; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::PointAndEval; use goldilocks::{Goldilocks, GoldilocksExt2}; + use itertools::Itertools; use multilinear_extensions::mle::IntoMLE; - use simple_frontend::structs::WitnessId; use transcript::Transcript; use crate::{ @@ -134,17 +137,16 @@ mod test { let circuit = circuit_builder.finalize_circuit(); // generate mock witness - let mut wits_in = BTreeMap::new(); let num_instances = 1 << 2; - (0..circuit.num_witin as usize).for_each(|witness_id| { - wits_in.insert( - witness_id as WitnessId, + let wits_in = (0..circuit.num_witin as usize) + .map(|_| { (0..num_instances) .map(|_| Goldilocks::random(&mut rng)) .collect::>() - .into_mle(), - ); - }); + .into_mle() + .into() + }) + .collect_vec(); // get proof let prover = ZKVMProver::new(circuit.clone()); // circuit clone due to verifier alos need circuit reference diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 0cda4989e..93a86e660 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -1,3 +1,5 @@ pub(crate) const MIN_PAR_SIZE: usize = 64; pub(crate) const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup +pub(crate) const SEL_DEGREE: usize = 2; + pub const NUM_FANIN: usize = 2; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 9d5525440..ed3c67c17 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,17 +1,14 @@ -use std::collections::BTreeMap; +use std::collections::BTreeSet; use ff_ext::ExtensionField; use gkr::{entered_span, exit_span, structs::Point}; use itertools::Itertools; use multilinear_extensions::{ - mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension}, - util::ceil_log2, - virtual_poly::build_eq_x_r_vec, + mle::IntoMLE, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, }; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; -use simple_frontend::structs::WitnessId; use sumcheck::structs::{IOPProverMessage, IOPProverStateV2}; use transcript::Transcript; @@ -44,9 +41,9 @@ impl ZKVMProver { /// major flow break down into /// 1: witness layer inferring from input -> output /// 2: proof (sumcheck reduce) from output to input - pub fn create_proof( + pub fn create_proof<'a>( &self, - witnesses: BTreeMap>, + witnesses: Vec>, num_instances: usize, max_threads: usize, transcript: &mut Transcript, @@ -58,9 +55,9 @@ impl ZKVMProver { // sanity check assert_eq!(witnesses.len(), circuit.num_witin as usize); - witnesses.iter().all(|(_, v)| { + assert!(witnesses.iter().all(|v| { v.num_vars() == log2_num_instances && v.evaluations().len() == next_pow2_instances - }); + })); // main constraint: read/write record witness inference let span = entered_span!("wit_inference::record"); @@ -222,15 +219,24 @@ impl ZKVMProver { // batch sumcheck: selector + main degree > 1 constraints let span = entered_span!("sumcheck::main_sel"); - let (rt_r, rt_w, rt_lk): (Vec, Vec, Vec) = ( + let (rt_r, rt_w, rt_lk, rt_non_lc_sumcheck): (Vec, Vec, Vec, Vec) = ( rt_tower[..log2_num_instances + log2_r_count].to_vec(), rt_tower[..log2_num_instances + log2_w_count].to_vec(), rt_tower[..log2_num_instances + log2_lk_count].to_vec(), + rt_tower[..log2_num_instances].to_vec(), ); let num_threads = proper_num_threads(log2_num_instances, max_threads); - let alpha_pow = get_challenge_pows(MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, transcript); - let (alpha_read, alpha_write, alpha_lk) = (&alpha_pow[0], &alpha_pow[1], &alpha_pow[2]); + let alpha_pow = get_challenge_pows( + MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + circuit.assert_zero_sumcheck_expressions.len(), + transcript, + ); + let mut alpha_pow_iter = alpha_pow.iter(); + let (alpha_read, alpha_write, alpha_lk) = ( + alpha_pow_iter.next().unwrap(), + alpha_pow_iter.next().unwrap(), + alpha_pow_iter.next().unwrap(), + ); // create selector: all ONE, but padding ZERO to ceil_log2 let (sel_r, sel_w, sel_lk): ( ArcMultilinearExtension, @@ -259,6 +265,25 @@ impl ZKVMProver { sel_lk.into_mle().into(), ) }; + + // only initialize when circuit got assert_zero_sumcheck_expressions + let sel_non_lc_zero_sumcheck = { + if !circuit.assert_zero_sumcheck_expressions.is_empty() { + let mut sel_non_lc_zero_sumcheck = build_eq_x_r_vec(&rt_non_lc_sumcheck); + if num_instances < sel_non_lc_zero_sumcheck.len() { + sel_non_lc_zero_sumcheck.splice( + num_instances..sel_non_lc_zero_sumcheck.len(), + std::iter::repeat(E::ZERO), + ); + } + let sel_non_lc_zero_sumcheck: ArcMultilinearExtension = + sel_non_lc_zero_sumcheck.into_mle().into(); + Some(sel_non_lc_zero_sumcheck) + } else { + None + } + }; + let mut virtual_polys = VirtualPolynomials::::new(num_threads, log2_num_instances); let sel_r_threads = virtual_polys.get_all_range_polys(&sel_r); let sel_w_threads: Vec> = @@ -331,6 +356,34 @@ impl ZKVMProver { } } + let mut distrinct_zerocheck_terms_set = BTreeSet::new(); + // degree > 1 zero expression sumcheck + if !circuit.assert_zero_sumcheck_expressions.is_empty() { + assert!(sel_non_lc_zero_sumcheck.is_some()); + let sel_non_lc_zero_sumcheck_threads: Vec> = + virtual_polys.get_all_range_polys(sel_non_lc_zero_sumcheck.as_ref().unwrap()); + + let witnesses_threads: Vec>> = witnesses + .iter() + .map(|wit_poly| virtual_polys.get_all_range_polys(wit_poly)) + .collect(); + + // \sum_t (sel(rt, t) * (\sum_j alpha_{j} * all_monomial_terms(t) )) + for (expr, alpha) in circuit + .assert_zero_sumcheck_expressions + .iter() + .zip_eq(alpha_pow_iter) + { + distrinct_zerocheck_terms_set.extend(virtual_polys.add_mle_list_by_expr( + Some(sel_non_lc_zero_sumcheck_threads.clone()), + &witnesses_threads, + expr, + challenges, + *alpha, + )); + } + } + let (main_sel_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( num_threads, virtual_polys.get_batched_polys(), @@ -339,7 +392,15 @@ impl ZKVMProver { let main_sel_evals = state.get_mle_final_evaluations(); assert_eq!( main_sel_evals.len(), - r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance + 3 + r_counts_per_instance + + w_counts_per_instance + + lk_counts_per_instance + + 3 + + if circuit.assert_zero_sumcheck_expressions.is_empty() { + 0 + } else { + distrinct_zerocheck_terms_set.len() + 1 // 1 from sel_non_lc_zero_sumcheck + } ); // 3 from [sel_r, sel_w, sel_lk] let mut main_sel_evals_iter = main_sel_evals.into_iter(); main_sel_evals_iter.next(); // skip sel_r @@ -354,7 +415,16 @@ impl ZKVMProver { let lk_records_in_evals = (0..lk_counts_per_instance) .map(|_| main_sel_evals_iter.next().unwrap()) .collect_vec(); - assert!(main_sel_evals_iter.next().is_none()); + assert!( + // we can skip all the rest of degree > 1 monomial terms because all the witness evaluation will be evaluated at last step + // and pass to verifier + main_sel_evals_iter.count() + == if circuit.assert_zero_sumcheck_expressions.is_empty() { + 0 + } else { + distrinct_zerocheck_terms_set.len() + 1 + } + ); let input_open_point = main_sel_sumcheck_proofs.point.clone(); assert!(input_open_point.len() == log2_num_instances); exit_span!(span); @@ -362,7 +432,7 @@ impl ZKVMProver { let span = entered_span!("witin::evals"); let wits_in_evals = witnesses .par_iter() - .map(|(_, poly)| poly.evaluate(&input_open_point)) + .map(|poly| poly.evaluate(&input_open_point)) .collect(); exit_span!(span); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index f4a12f511..bb9a98192 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::sync::Arc; use ark_std::iterable::Iterable; use ff_ext::ExtensionField; @@ -17,7 +17,6 @@ use rayon::{ }, prelude::ParallelSliceMut, }; -use simple_frontend::structs::WitnessId; use crate::{expression::Expression, scheme::constants::MIN_PAR_SIZE}; @@ -202,20 +201,12 @@ pub(crate) fn infer_tower_product_witness<'a, E: ExtensionField>( } pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField>( - witnesses: &BTreeMap>, + witnesses: &[ArcMultilinearExtension<'a, E>], challenges: &[E], expr: &Expression, ) -> ArcMultilinearExtension<'a, E> { expr.evaluate::>( - &|witness_id| { - let a: ArcMultilinearExtension = Arc::new( - witnesses - .get(&witness_id) - .expect("non exist witness") - .clone(), - ); - a - }, + &|witness_id| witnesses[witness_id as usize].clone(), &|scalar| { let scalar: ArcMultilinearExtension = Arc::new( DenseMultilinearExtension::from_evaluations_vec(0, vec![scalar]), diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 17628be96..f00c15232 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -15,7 +15,10 @@ use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::Transcript; use crate::{ - circuit_builder::Circuit, error::ZKVMError, scheme::constants::NUM_FANIN, structs::TowerProofs, + circuit_builder::Circuit, + error::ZKVMError, + scheme::constants::{NUM_FANIN, SEL_DEGREE}, + structs::TowerProofs, utils::get_challenge_pows, }; @@ -97,8 +100,18 @@ impl ZKVMVerifier { rt_tower[..log2_num_instances + log2_lk_count].to_vec(), ); - let alpha_pow = get_challenge_pows(MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, transcript); - let (alpha_read, alpha_write, alpha_lk) = (&alpha_pow[0], &alpha_pow[1], &alpha_pow[2]); + let alpha_pow = get_challenge_pows( + MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + self.circuit.assert_zero_sumcheck_expressions.len(), + transcript, + ); + let mut alpha_pow_iter = alpha_pow.iter(); + let (alpha_read, alpha_write, alpha_lk) = ( + alpha_pow_iter.next().unwrap(), + alpha_pow_iter.next().unwrap(), + alpha_pow_iter.next().unwrap(), + ); + // alpha_read * (out_r[rt] - 1) + alpha_write * (out_w[rt] - 1) + alpha_lk * (out_lk_q) + // + 0 // 0 come from zero check let claim_sum = *alpha_read * (record_evals[0] - E::ONE) + *alpha_write * (record_evals[1] - E::ONE) + *alpha_lk * (logup_q_evals[0]); @@ -109,13 +122,13 @@ impl ZKVMVerifier { proofs: proof.main_sel_sumcheck_proofs.clone(), }, &VPAuxInfo { - max_degree: 2, + max_degree: SEL_DEGREE.max(self.circuit.max_non_lc_degree), num_variables: log2_num_instances, phantom: PhantomData, }, transcript, ); - let (main_sel_eval_point, expected_evaluation) = ( + let (input_opening_point, expected_evaluation) = ( main_sel_subclaim .point .iter() @@ -127,8 +140,8 @@ impl ZKVMVerifier { let eq_w = build_eq_x_r_vec_sequential(&rt_w[..log2_w_count]); let eq_lk = build_eq_x_r_vec_sequential(&rt_lk[..log2_lk_count]); - let (sel_r, sel_w, sel_lk) = { - // TODO optimize sel evaluation + let (sel_r, sel_w, sel_lk, sel_non_lc_zero_sumcheck) = { + // TODO make sel evaluation succint let mut sel = vec![E::BaseField::ONE; 1 << log2_num_instances]; if num_instances < sel.len() { sel.splice( @@ -138,14 +151,27 @@ impl ZKVMVerifier { } let sel = sel.into_mle(); ( - eq_eval(&rt_r[log2_r_count..], &main_sel_eval_point) + eq_eval(&rt_r[log2_r_count..], &input_opening_point) * sel.evaluate(&rt_r[log2_r_count..]), - eq_eval(&rt_w[log2_w_count..], &main_sel_eval_point) + eq_eval(&rt_w[log2_w_count..], &input_opening_point) * sel.evaluate(&rt_w[log2_w_count..]), - eq_eval(&rt_lk[log2_lk_count..], &main_sel_eval_point) + eq_eval(&rt_lk[log2_lk_count..], &input_opening_point) * sel.evaluate(&rt_lk[log2_lk_count..]), + // only initialize when circuit got non empty assert_zero_sumcheck_expressions + { + let rt_non_lc_sumcheck = rt_tower[..log2_num_instances].to_vec(); + if !self.circuit.assert_zero_sumcheck_expressions.is_empty() { + Some( + eq_eval(&rt_non_lc_sumcheck, &input_opening_point) + * sel.evaluate(&rt_non_lc_sumcheck), + ) + } else { + None + } + }, ) }; + let computed_evals = [ // read *alpha_read @@ -169,6 +195,21 @@ impl ZKVMVerifier { * ((0..lk_counts_per_instance) .map(|i| proof.lk_records_in_evals[i] * eq_lk[i]) .sum::()), + // degree > 1 zero exp sumcheck + { + // sel(rt_non_lc_sumcheck, main_sel_eval_point) * \sum_j (alpha{j} * expr(main_sel_eval_point)) + sel_non_lc_zero_sumcheck.unwrap_or(E::ZERO) + * self + .circuit + .assert_zero_sumcheck_expressions + .iter() + .zip_eq(alpha_pow_iter) + .map(|(expr, alpha)| { + // evaluate zero expression by all wits_in_evals because they share the unique input_opening_point opening + *alpha * eval_by_expr(&proof.wits_in_evals, challenges, &expr) + }) + .sum::() + }, ] .iter() .sum::(); @@ -197,8 +238,6 @@ impl ZKVMVerifier { return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); } - let input_opening_point = main_sel_eval_point; - // verify zero expression (degree = 1) statement, thus no sumcheck if self .circuit diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 3b3da0b3a..7386a5be3 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -58,13 +58,22 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { self.polys } + /// add mle terms into virtual poly by expression + /// return distinct witin in set pub fn add_mle_list_by_expr( &mut self, + // thread_based selector + selector: Option>>, // witin_id -> thread_id - wit_ins: Vec>>, + wit_ins: &[Vec>], expr: &Expression, challenges: &[E], - ) { + // sumcheck batch challenge + alpha: E, + ) -> BTreeSet { + if let Some(sel) = &selector { + assert_eq!(sel.len(), self.num_threads); + } assert!(expr.is_monomial_form()); let monomial_terms = expr.evaluate( &|witness_id| { @@ -122,14 +131,28 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { todo!("make virtual poly support pure constant") } for thread_id in 0..self.num_threads { + let sel = selector + .as_ref() + .map(|sel| vec![sel[thread_id].clone()]) + .unwrap_or(vec![]); let terms_polys = monomial_term .iter() .map(|wit_id| wit_ins[*wit_id as usize][thread_id].clone()) .collect_vec(); - self.add_mle_list(thread_id, terms_polys, *constant); + self.add_mle_list( + thread_id, + vec![sel, terms_polys].concat(), + *constant * alpha, + ); } } + + let num_distinct_witins = monomial_terms + .into_iter() + .flat_map(|(_, monomial_term)| monomial_term.into_iter().collect_vec()) + .collect::>(); + num_distinct_witins } } @@ -166,7 +189,8 @@ mod tests { let expr: Expression = Expression::from(3) * x.expr() * y.expr() + Expression::from(2) * y.expr(); - println!("expr {:?}", expr); - virtual_polys.add_mle_list_by_expr(wits_threads, &expr, &[]); + let distrinct_zerocheck_terms_set = + virtual_polys.add_mle_list_by_expr(None, &wits_threads, &expr, &[], 1.into()); + assert!(distrinct_zerocheck_terms_set.len() == 2); } }