Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

degree > 1 main constraint sumcheck implementation #107 #108

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#![allow(clippy::manual_memcpy)]
#![allow(clippy::needless_range_loop)]

use std::{
collections::BTreeMap,
time::{Duration, Instant},
};
use std::time::{Duration, Instant};

use ark_std::test_rng;
use ceno_zkvm::{
Expand All @@ -17,8 +14,8 @@ use criterion::*;

use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLE;
use simple_frontend::structs::WitnessId;
use transcript::Transcript;

cfg_if::cfg_if! {
Expand Down Expand Up @@ -93,17 +90,16 @@ fn bench_add(c: &mut Criterion) {
},
|(mut rng, real_challenges)| {
// generate mock witness
let mut wits_in = BTreeMap::new();
let num_instances = 1 << instance_num_vars;
(0..num_witin as usize).for_each(|witness_id| {
wits_in.insert(
witness_id as WitnessId,
let wits_in = (0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle(),
);
});
.into_mle()
.into()
})
.collect_vec();
let timer = Instant::now();
let _ = prover
.create_proof(
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/chip_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ff_ext::ExtensionField;

use crate::{
error::ZKVMError,
expression::{Expression, WitIn},
expression::WitIn,
structs::{PCUInt, TSUInt, UInt64},
};

Expand Down
9 changes: 8 additions & 1 deletion ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ impl<E: ExtensionField> CircuitBuilder<E> {
lk_expressions: vec![],
assert_zero_expressions: vec![],
assert_zero_sumcheck_expressions: vec![],
max_non_lc_degree: 0,
chip_record_alpha: Expression::Challenge(0, 1, E::ONE, E::ZERO),
chip_record_beta: Expression::Challenge(1, 1, E::ONE, E::ZERO),
phantom: std::marker::PhantomData,
Expand Down Expand Up @@ -96,7 +97,12 @@ impl<E: ExtensionField> CircuitBuilder<E> {
if assert_zero_expr.degree() == 1 {
self.assert_zero_expressions.push(assert_zero_expr);
} else {
// TODO check expression must be in multivariate monomial form
assert_eq!(
assert_zero_expr.is_monomial_form(),
true,
"only support sumcheck in monomial form"
);
self.max_non_lc_degree = self.max_non_lc_degree.max(assert_zero_expr.degree());
self.assert_zero_sumcheck_expressions.push(assert_zero_expr);
}
Ok(())
Expand Down Expand Up @@ -132,6 +138,7 @@ impl<E: ExtensionField> CircuitBuilder<E> {
lk_expressions: self.lk_expressions.clone(),
assert_zero_expressions: self.assert_zero_expressions.clone(),
assert_zero_sumcheck_expressions: self.assert_zero_sumcheck_expressions.clone(),
max_non_lc_degree: self.max_non_lc_degree,
}
}
}
5 changes: 5 additions & 0 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub struct CircuitBuilder<E: ExtensionField> {
pub assert_zero_expressions: Vec<Expression<E>>,
/// main constraints zero expression for expression degree > 1, which require sumcheck to prove
pub assert_zero_sumcheck_expressions: Vec<Expression<E>>,
/// max zero sumcheck degree
pub max_non_lc_degree: usize,

// alpha, beta challenge for chip record
pub chip_record_alpha: Expression<E>,
Expand All @@ -39,4 +41,7 @@ pub struct Circuit<E: ExtensionField> {
pub assert_zero_expressions: Vec<Expression<E>>,
/// main constraints zero expression for expression degree > 1, which require sumcheck to prove
pub assert_zero_sumcheck_expressions: Vec<Expression<E>>,

/// max zero sumcheck degree
pub max_non_lc_degree: usize,
}
103 changes: 102 additions & 1 deletion ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,18 @@ pub enum Expression<E: ExtensionField> {
Sum(Box<Expression<E>>, Box<Expression<E>>),
/// This is the product of two polynomials
Product(Box<Expression<E>>, Box<Expression<E>>),
/// This is a ax + b polynomial
/// This is x, a, b expr to represent ax + b polynomial
ScaledSum(Box<Expression<E>>, Box<Expression<E>>, Box<Expression<E>>),
Challenge(ChallengeId, usize, E, E), // (challenge_id, power, scalar, offset)
}

/// this is used as finite state machine state
/// for differentiate a expression is in monomial form or not
enum MonomialState {
SumTerm,
ProductTerm,
}

impl<E: ExtensionField> Expression<E> {
pub fn degree(&self) -> usize {
match self {
Expand Down Expand Up @@ -69,6 +76,46 @@ impl<E: ExtensionField> Expression<E> {
}
}
}

pub fn is_monomial_form(&self) -> bool {
Self::is_monomial_form_inner(MonomialState::SumTerm, self)
}

fn is_zero_expr(expr: &Expression<E>) -> bool {
match expr {
Expression::WitIn(_) => false,
Expression::Constant(c) => *c == E::BaseField::ZERO,
Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b),
Expression::Product(a, b) => Self::is_zero_expr(a) || Self::is_zero_expr(b),
Expression::ScaledSum(_, _, _) => false,
Expression::Challenge(_, _, _, _) => false,
}
}
fn is_monomial_form_inner(s: MonomialState, expr: &Expression<E>) -> bool {
match (expr, s) {
(Expression::WitIn(_), MonomialState::SumTerm) => true,
(Expression::WitIn(_), MonomialState::ProductTerm) => true,
(Expression::Constant(_), MonomialState::SumTerm) => true,
(Expression::Constant(_), MonomialState::ProductTerm) => true,
(Expression::Sum(a, b), MonomialState::SumTerm) => {
Self::is_monomial_form_inner(MonomialState::SumTerm, a)
&& Self::is_monomial_form_inner(MonomialState::SumTerm, b)
}
(Expression::Sum(_, _), MonomialState::ProductTerm) => false,
(Expression::Product(a, b), MonomialState::SumTerm) => {
Self::is_monomial_form_inner(MonomialState::ProductTerm, a)
&& Self::is_monomial_form_inner(MonomialState::ProductTerm, b)
}
(Expression::Product(a, b), MonomialState::ProductTerm) => {
Self::is_monomial_form_inner(MonomialState::ProductTerm, a)
&& Self::is_monomial_form_inner(MonomialState::ProductTerm, b)
}
(Expression::ScaledSum(_, _, _), MonomialState::SumTerm) => true,
(Expression::ScaledSum(_, _, b), MonomialState::ProductTerm) => Self::is_zero_expr(b),
(Expression::Challenge(_, _, _, _), MonomialState::SumTerm) => true,
(Expression::Challenge(_, _, _, _), MonomialState::ProductTerm) => true,
}
}
}

impl<E: ExtensionField> Neg for Expression<E> {
Expand Down Expand Up @@ -425,4 +472,58 @@ mod tests {
)
);
}

#[test]
fn test_is_monomial_form() {
type E = GoldilocksExt2;
let mut cb = CircuitBuilder::<E>::new();
let x = cb.create_witin();
let y = cb.create_witin();
let z = cb.create_witin();
// scaledsum * challenge
// 3 * x + 2
let expr: Expression<E> =
Into::<Expression<E>>::into(3usize) * x.expr() + Into::<Expression<E>>::into(2usize);
assert_eq!(expr.is_monomial_form(), true);

// 2 product term
let expr: Expression<E> = Into::<Expression<E>>::into(3usize) * x.expr() * y.expr()
+ Into::<Expression<E>>::into(2usize) * x.expr();
assert_eq!(expr.is_monomial_form(), true);

// complex linear operation
// (2c + 3) * x * y - 6z
let expr: Expression<E> =
Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr()
- Into::<Expression<E>>::into(6usize) * z.expr();
assert_eq!(expr.is_monomial_form(), true);

// complex linear operation
// (2c + 3) * x * y - 6z
let expr: Expression<E> =
Expression::Challenge(0, 1, 2.into(), 3.into()) * x.expr() * y.expr()
- Into::<Expression<E>>::into(6usize) * z.expr();
assert_eq!(expr.is_monomial_form(), true);

// complex linear operation
// (2 * x + 3) * 3 + 6 * 8
let expr: Expression<E> = (Into::<Expression<E>>::into(2usize) * x.expr()
+ Into::<Expression<E>>::into(3usize))
* Into::<Expression<E>>::into(3usize)
+ Into::<Expression<E>>::into(6usize) * Into::<Expression<E>>::into(8usize);
assert_eq!(expr.is_monomial_form(), true);
}

#[test]
fn test_not_monomial_form() {
type E = GoldilocksExt2;
let mut cb = CircuitBuilder::<E>::new();
let x = cb.create_witin();
let y = cb.create_witin();
// scaledsum * challenge
// (x + 1) * (y + 1)
let expr: Expression<E> = (Into::<Expression<E>>::into(1usize) + x.expr())
* (Into::<Expression<E>>::into(2usize) + y.expr());
assert_eq!(expr.is_monomial_form(), false);
}
}
20 changes: 11 additions & 9 deletions ceno_zkvm/src/instructions/riscv/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
circuit_builder.assert_u5(rs2_id.expr())?;
circuit_builder.assert_u5(rd_id.expr())?;

// TODO remove me, this is just for testing degree > 1 sumcheck in main constraints
circuit_builder
.require_zero(rs1_id.expr() * rs1_id.expr() - rs1_id.expr() * rs1_id.expr())?;

let mut prev_rs1_ts = TSUInt::new(circuit_builder);
let mut prev_rs2_ts = TSUInt::new(circuit_builder);
let mut prev_rd_ts = TSUInt::new(circuit_builder);
Expand Down Expand Up @@ -106,15 +110,14 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {

#[cfg(test)]
mod test {
use std::collections::BTreeMap;

use ark_std::test_rng;
use ff::Field;
use ff_ext::ExtensionField;
use gkr::structs::PointAndEval;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLE;
use simple_frontend::structs::WitnessId;
use transcript::Transcript;

use crate::{
Expand All @@ -134,17 +137,16 @@ mod test {
let circuit = circuit_builder.finalize_circuit();

// generate mock witness
let mut wits_in = BTreeMap::new();
let num_instances = 1 << 2;
(0..circuit.num_witin as usize).for_each(|witness_id| {
wits_in.insert(
witness_id as WitnessId,
let wits_in = (0..circuit.num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle(),
);
});
.into_mle()
.into()
})
.collect_vec();

// get proof
let prover = ZKVMProver::new(circuit.clone()); // circuit clone due to verifier alos need circuit reference
Expand Down
2 changes: 2 additions & 0 deletions ceno_zkvm/src/scheme/constants.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub(crate) const MIN_PAR_SIZE: usize = 64;
pub(crate) const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup
pub(crate) const SEL_DEGREE: usize = 2;

pub const NUM_FANIN: usize = 2;
Loading
Loading