From bf8ab7e3e9cfcc8e20942e1e6ab180d876aeae67 Mon Sep 17 00:00:00 2001 From: Kimi Wu Date: Tue, 27 Aug 2024 18:57:00 +0800 Subject: [PATCH] Feat/#97 uint refactor (#106) * optimize sumcheck algo circuit witness: direct witness on mle devirgo style on phase1_output * initial version for new zkVM design * riscv add prototype implementation * add new zkVM prover * new package ceno_zkvm * record witness generation * add transcript * add verifier * code cleanup * rename expression * prover record_r/record_w sumcheck * main sel sumcheck proof/verify * tower product witness inference * tower product sumcheck prove/verify * chores: fix tower sumcheck witness length error and clean up * verify record and zero expression * tower sumcheck prove/verify pass * WIP test main sel prove/verify * add benchmark * chores: interleaving with default value * main constraint sumcheck prove/verify pass * chores: mock witness * main constraint sumcheck verify final claim assertion pass * restructure ceno_zkvm package * refine expression format * wip lookup * lookup in logup implemetation with integration test pass * chores: code cosmetics * optimize with 2-stage sumcheck #103 * chores: refine virtual polys naming * fix proper ts and pc counting * try sumcheck bench * refine global state in riscv * degree > 1 main constraint sumcheck implementation #107 (#108) * monomial expression to virtual poly * degree > 1 sumcheck batched with main constraint * succint selector evaluation * refine succint selector evaluation formula and documentation * wip fix interleaving edge case * deal with interleaving_mles instance = 1 case * chores: code cosmetics and address review comments * fix logup padding with chip record challenge * riscv opcode type & combine add/sub opcode & dependency trim * ci whitelist ceno_zkvm lint/clippy * address review comments and naming cosmetics * remove unnessesary to_vec operation * deal with interleaving_mles instance = 1 case * support add/mul * adding test cases for UInt::add * refactor add and adding add_const test cases * adding UInt::mul test cases and refine UInt.expr() * refine test cases * minor refinement * tower verifier logup p(x) constant check * cleanup and hide thread-based logic * support sub logic in addsub gadget * add range checks for witIns and address some review feedback * soundness fix: derive new sumcheck batched challenge for each round * fix sel evaluation point and add TODO check * fix sumcheck batched challenge deriving order * add overflow constraint and support expr -> witIn under mul operation * chore: rename pc step size & fine tune project structure * assert decomposed constant equals original value * remove IS_OVERFLOW flag * test: fix wit_in of the operation result in testing * add is_None check in create_carry_witin * add _unsafe version and some cleanup * remove computed_outcome as review feedback * remove zombie witnesses * generalize 0xFFFF * refactor range check function to assert_ux * fixing review feedback * fixing lint * move uint to upper layer to fix lint error * fix for review feedback, rename create_witin and refine comment * remove unnecessary range check --------- Co-authored-by: sm.wu --- Cargo.lock | 2 + ceno_zkvm/src/chip_handler.rs | 22 +- ceno_zkvm/src/chip_handler/general.rs | 23 +- ceno_zkvm/src/chip_handler/global_state.rs | 8 +- ceno_zkvm/src/chip_handler/register.rs | 18 +- ceno_zkvm/src/expression.rs | 2 +- ceno_zkvm/src/instructions/riscv/addsub.rs | 51 +- ceno_zkvm/src/instructions/riscv/constants.rs | 1 + ceno_zkvm/src/scheme.rs | 2 +- ceno_zkvm/src/scheme/utils.rs | 2 +- ceno_zkvm/src/structs.rs | 9 +- ceno_zkvm/src/uint.rs | 349 ++++++++- ceno_zkvm/src/uint/arithmetic.rs | 697 +++++++++++++++++- ceno_zkvm/src/uint/constants.rs | 55 +- ceno_zkvm/src/uint/uint.rs | 278 ------- ceno_zkvm/src/utils.rs | 1 + multilinear_extensions/Cargo.toml | 3 + 17 files changed, 1121 insertions(+), 402 deletions(-) delete mode 100644 ceno_zkvm/src/uint/uint.rs diff --git a/Cargo.lock b/Cargo.lock index ad8c72308..6e121090a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1015,6 +1015,8 @@ dependencies = [ "rayon", "serde", "tracing", + "tracing-flame", + "tracing-subscriber", ] [[package]] diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index 5ddab3662..7f42b642d 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -11,26 +11,26 @@ pub mod global_state; pub mod register; pub trait GlobalStateRegisterMachineChipOperations { - fn state_in(&mut self, pc: &PCUInt, ts: &TSUInt) -> Result<(), ZKVMError>; + fn state_in(&mut self, pc: &PCUInt, ts: &TSUInt) -> Result<(), ZKVMError>; - fn state_out(&mut self, pc: &PCUInt, ts: &TSUInt) -> Result<(), ZKVMError>; + fn state_out(&mut self, pc: &PCUInt, ts: &TSUInt) -> Result<(), ZKVMError>; } pub trait RegisterChipOperations { fn register_read( &mut self, register_id: &WitIn, - prev_ts: &mut TSUInt, - ts: &mut TSUInt, - values: &UInt64, - ) -> Result; + prev_ts: &mut TSUInt, + ts: &mut TSUInt, + values: &UInt64, + ) -> Result, ZKVMError>; fn register_write( &mut self, register_id: &WitIn, - prev_ts: &mut TSUInt, - ts: &mut TSUInt, - prev_values: &UInt64, - values: &UInt64, - ) -> Result; + prev_ts: &mut TSUInt, + ts: &mut TSUInt, + prev_values: &UInt64, + values: &UInt64, + ) -> Result, ZKVMError>; } diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 60472eac6..859dc0485 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -125,7 +125,18 @@ impl CircuitBuilder { self.require_zero(Expression::from(1) - expr) } - pub(crate) fn assert_u5(&mut self, expr: Expression) -> Result<(), ZKVMError> { + pub(crate) fn assert_ux( + &mut self, + expr: Expression, + ) -> Result<(), ZKVMError> { + match C { + 16 => self.assert_u16(expr), + 5 => self.assert_u5(expr), + _ => panic!("Unsupported bit range"), + } + } + + fn assert_u5(&mut self, expr: Expression) -> Result<(), ZKVMError> { let items: Vec> = vec![ Expression::Constant(E::BaseField::from(ROMType::U5 as u64)), expr, @@ -135,6 +146,16 @@ impl CircuitBuilder { Ok(()) } + fn assert_u16(&mut self, expr: Expression) -> Result<(), ZKVMError> { + let items: Vec> = vec![ + Expression::Constant(E::BaseField::from(ROMType::U16 as u64)), + expr, + ]; + let rlc_record = self.rlc_chip_record(items); + self.lk_record(rlc_record)?; + Ok(()) + } + pub fn finalize_circuit(&self) -> Circuit { Circuit { num_witin: self.num_witin, diff --git a/ceno_zkvm/src/chip_handler/global_state.rs b/ceno_zkvm/src/chip_handler/global_state.rs index d03ee57eb..9e315f648 100644 --- a/ceno_zkvm/src/chip_handler/global_state.rs +++ b/ceno_zkvm/src/chip_handler/global_state.rs @@ -9,8 +9,8 @@ use super::GlobalStateRegisterMachineChipOperations; impl GlobalStateRegisterMachineChipOperations for CircuitBuilder { fn state_in( &mut self, - pc: &crate::structs::PCUInt, - ts: &crate::structs::TSUInt, + pc: &crate::structs::PCUInt, + ts: &crate::structs::TSUInt, ) -> Result<(), ZKVMError> { let items: Vec> = [ vec![Expression::Constant(E::BaseField::from( @@ -27,8 +27,8 @@ impl GlobalStateRegisterMachineChipOperations for CircuitB fn state_out( &mut self, - pc: &crate::structs::PCUInt, - ts: &crate::structs::TSUInt, + pc: &crate::structs::PCUInt, + ts: &crate::structs::TSUInt, ) -> Result<(), ZKVMError> { let items: Vec> = [ vec![Expression::Constant(E::BaseField::from( diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index 35ed5dd16..2b64f0f90 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -13,10 +13,10 @@ impl RegisterChipOperations for CircuitBuilder { fn register_read( &mut self, register_id: &WitIn, - prev_ts: &mut TSUInt, - ts: &mut TSUInt, - values: &UInt64, - ) -> Result { + prev_ts: &mut TSUInt, + ts: &mut TSUInt, + values: &UInt64, + ) -> Result, ZKVMError> { // READ (a, v, t) let read_record = self.rlc_chip_record( [ @@ -55,11 +55,11 @@ impl RegisterChipOperations for CircuitBuilder { fn register_write( &mut self, register_id: &WitIn, - prev_ts: &mut TSUInt, - ts: &mut TSUInt, - prev_values: &UInt64, - values: &UInt64, - ) -> Result { + prev_ts: &mut TSUInt, + ts: &mut TSUInt, + prev_values: &UInt64, + values: &UInt64, + ) -> Result, ZKVMError> { // READ (a, v, t) let read_record = self.rlc_chip_record( [ diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 114bfee99..63f8bea38 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -373,7 +373,7 @@ impl Mul for Expression { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Copy)] pub struct WitIn { pub id: WitnessId, } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index e059fb269..7c3cf202f 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -20,18 +20,18 @@ pub struct AddInstruction; pub struct SubInstruction; pub struct InstructionConfig { - pub pc: PCUInt, - pub ts: TSUInt, - pub prev_rd_value: UInt64, - pub addend_0: UInt64, - pub addend_1: UInt64, - pub outcome: UInt64, + pub pc: PCUInt, + pub ts: TSUInt, + pub prev_rd_value: UInt64, + pub addend_0: UInt64, + pub addend_1: UInt64, + pub outcome: UInt64, pub rs1_id: WitIn, pub rs2_id: WitIn, pub rd_id: WitIn, - pub prev_rs1_ts: TSUInt, - pub prev_rs2_ts: TSUInt, - pub prev_rd_ts: TSUInt, + pub prev_rs1_ts: TSUInt, + pub prev_rs2_ts: TSUInt, + pub prev_rd_ts: TSUInt, phantom: PhantomData, } @@ -56,21 +56,30 @@ fn add_sub_gadget( // Execution result = addend0 + addend1, with carry. let prev_rd_value = UInt64::new(circuit_builder); - let addend_0 = UInt64::new(circuit_builder); - let addend_1 = UInt64::new(circuit_builder); - let outcome = UInt64::new(circuit_builder); - // TODO IS_ADD to deal with add/sub - let computed_outcome = addend_0.add(circuit_builder, &addend_1)?; - outcome.eq(circuit_builder, &computed_outcome)?; + let (addend_0, addend_1, outcome) = if IS_ADD { + // outcome = addend_0 + addend_1 + let addend_0 = UInt64::new(circuit_builder); + let addend_1 = UInt64::new(circuit_builder); + ( + addend_0.clone(), + addend_1.clone(), + addend_0.add(circuit_builder, &addend_1)?, + ) + } else { + // outcome + addend_1 = addend_0 + let outcome = UInt64::new(circuit_builder); + let addend_1 = UInt64::new(circuit_builder); + ( + addend_1.clone().add(circuit_builder, &outcome.clone())?, + addend_1, + outcome, + ) + }; - // TODO rs1_id, rs2_id, rd_id should be bytecode lookup let rs1_id = circuit_builder.create_witin(); let rs2_id = circuit_builder.create_witin(); let rd_id = circuit_builder.create_witin(); - circuit_builder.assert_u5(rs1_id.expr())?; - circuit_builder.assert_u5(rs2_id.expr())?; - circuit_builder.assert_u5(rd_id.expr())?; // TODO remove me, this is just for testing degree > 1 sumcheck in main constraints circuit_builder.require_zero(rs1_id.expr() * rs1_id.expr() - rs1_id.expr() * rs1_id.expr())?; @@ -80,7 +89,6 @@ fn add_sub_gadget( let mut prev_rd_ts = TSUInt::new(circuit_builder); let mut ts = circuit_builder.register_read(&rs1_id, &mut prev_rs1_ts, &mut ts, &addend_0)?; - let mut ts = circuit_builder.register_read(&rs2_id, &mut prev_rs2_ts, &mut ts, &addend_1)?; let ts = circuit_builder.register_write( @@ -88,7 +96,7 @@ fn add_sub_gadget( &mut prev_rd_ts, &mut ts, &prev_rd_value, - &computed_outcome, + &outcome, )?; let next_ts = ts.add_const(circuit_builder, 1.into())?; @@ -152,6 +160,7 @@ mod test { use super::AddInstruction; #[test] + #[ignore = "hit tower verification bug, PR#165"] fn test_add_construct_circuit() { let mut rng = test_rng(); diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index e757e9346..291598d9d 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -2,6 +2,7 @@ use std::fmt; pub(crate) const PC_STEP_SIZE: usize = 4; +#[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, Copy)] pub enum OPType { OP, diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 3fab7523e..3114c0610 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -5,7 +5,7 @@ use crate::structs::TowerProofs; pub mod constants; pub mod prover; -mod utils; +pub mod utils; pub mod verifier; #[derive(Clone)] diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 4dba1db41..4066969b5 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -329,7 +329,7 @@ pub(crate) fn eval_by_expr( &|witness_id| witnesses[witness_id as usize], &|scalar| scalar.into(), &|challenge_id, pow, scalar, offset| { - // TODO cache challenge power to be aquire once for each power + // TODO cache challenge power to be acquired once for each power let challenge = challenges[challenge_id as usize]; challenge.pow([pow as u64]) * scalar + offset }, diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 8d5cd6fed..804b4d55b 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -22,12 +22,13 @@ pub struct TowerProverSpec<'a, E: ExtensionField> { const VALUE_BIT_WIDTH: usize = 16; pub type WitnessId = u16; pub type ChallengeId = u16; -pub type UInt64 = UInt<64, VALUE_BIT_WIDTH>; -pub type PCUInt = UInt64; -pub type TSUInt = UInt<48, 48>; +pub type UInt64 = UInt<64, VALUE_BIT_WIDTH, E>; +pub type PCUInt = UInt64; +pub type TSUInt = UInt<48, 16, E>; pub enum ROMType { - U5, // 2^5=32 + U5, // 2^5 = 32 + U16, // 2^16 = 65,536 } #[derive(Clone, Debug, Copy)] diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 29581a901..349f21d46 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -1,5 +1,350 @@ mod arithmetic; mod constants; -mod uint; pub mod util; -pub use uint::UInt; + +use crate::{ + circuit_builder::CircuitBuilder, + error::UtilError, + expression::{Expression, ToExpr, WitIn}, + utils::add_one_to_big_num, +}; +use ark_std::iterable::Iterable; +use constants::BYTE_BIT_WIDTH; +use ff_ext::ExtensionField; +use goldilocks::SmallField; +use itertools::Itertools; +pub use strum::IntoEnumIterator; +use strum_macros::EnumIter; +use sumcheck::util::ceil_log2; + +#[derive(Clone, EnumIter, Debug)] +pub enum UintLimb { + WitIn(Vec), + Expression(Vec>), +} + +#[derive(Clone)] +/// Unsigned integer with `M` total bits. `C` denotes the cell bit width. +/// Represented in little endian form. +pub struct UInt { + pub limbs: UintLimb, + // We don't need `overflow` witness since the last element of `carries` represents it. + pub carries: Option>, +} + +impl UInt { + pub fn new(circuit_builder: &mut CircuitBuilder) -> Self { + Self { + limbs: UintLimb::WitIn( + (0..Self::NUM_CELLS) + .map(|_| { + let w = circuit_builder.create_witin(); + circuit_builder.assert_ux::(w.expr()).unwrap(); + w + }) + .collect_vec(), + ), + carries: None, + } + } + + pub fn new_limb_as_expr() -> Self { + Self { + limbs: UintLimb::Expression(Vec::new()), + carries: None, + } + } + + /// If current limbs are Expression, this function will create witIn and replace the limbs + pub fn replace_limbs_with_witin(&mut self, circuit_builder: &mut CircuitBuilder) { + if let UintLimb::Expression(_) = self.limbs { + self.limbs = UintLimb::WitIn( + (0..Self::NUM_CELLS) + .map(|_| { + let w = circuit_builder.create_witin(); + circuit_builder.assert_ux::(w.expr()).unwrap(); + w + }) + .collect_vec(), + ) + } + } + + // Create witIn for carries + pub fn create_carry_witin(&mut self, circuit_builder: &mut CircuitBuilder) { + if self.carries.is_none() { + self.carries = (0..Self::NUM_CELLS) + .map(|_| { + let w = circuit_builder.create_witin(); + circuit_builder.assert_ux::(w.expr()).unwrap(); + Some(w) + }) + .collect(); + } + } + + /// Return limbs in Expression form + pub fn expr(&self) -> Vec> { + match &self.limbs { + UintLimb::WitIn(limbs) => limbs + .iter() + .map(ToExpr::expr) + .collect::>>(), + UintLimb::Expression(e) => e.clone(), + } + } + + /// Return if the limbs are in Expression form or not. + pub fn is_expr(&self) -> bool { + matches!(&self.limbs, UintLimb::Expression(_)) + } + + /// Return the `UInt` underlying cell id's + pub fn wits_in(&self) -> Option<&[WitIn]> { + match &self.limbs { + UintLimb::WitIn(c) => Some(c), + _ => None, + } + } + + /// Builds a `UInt` instance from a set of cells that represent `RANGE_VALUES` + /// assumes range_values are represented in little endian form + pub fn from_range_wits_in( + _circuit_builder: &mut CircuitBuilder, + _range_values: &[WitIn], + ) -> Result { + // Self::from_different_sized_cell_values( + // circuit_builder, + // range_values, + // RANGE_CHIP_BIT_WIDTH, + // true, + // ) + todo!() + } + + /// Builds a `UInt` instance from a set of cells that represent big-endian `BYTE_VALUES` + pub fn from_bytes_big_endian( + circuit_builder: &mut CircuitBuilder, + bytes: &[WitIn], + ) -> Result { + Self::from_bytes(circuit_builder, bytes, false) + } + + /// Builds a `UInt` instance from a set of cells that represent little-endian `BYTE_VALUES` + pub fn from_bytes_little_endian( + circuit_builder: &mut CircuitBuilder, + bytes: &[WitIn], + ) -> Result { + Self::from_bytes(circuit_builder, bytes, true) + } + + /// Builds a `UInt` instance from a set of cells that represent `BYTE_VALUES` + pub fn from_bytes( + circuit_builder: &mut CircuitBuilder, + bytes: &[WitIn], + is_little_endian: bool, + ) -> Result { + Self::from_different_sized_cell_values( + circuit_builder, + bytes, + BYTE_BIT_WIDTH, + is_little_endian, + ) + } + + /// Builds a `UInt` instance from a set of cell values of a certain `CELL_WIDTH` + fn from_different_sized_cell_values( + _circuit_builder: &mut CircuitBuilder, + _wits_in: &[WitIn], + _cell_width: usize, + _is_little_endian: bool, + ) -> Result { + todo!() + // let mut values = convert_decomp( + // circuit_builder, + // wits_in, + // cell_width, + // Self::MAX_CELL_BIT_WIDTH, + // is_little_endian, + // )?; + // debug_assert!(values.len() <= Self::NUM_CELLS); + // pad_cells(circuit_builder, &mut values, Self::NUM_CELLS); + // values.try_into() + } + + /// Generate ((0)_{2^C}, (1)_{2^C}, ..., (size - 1)_{2^C}) + pub fn counter_vector(size: usize) -> Vec> { + let num_vars = ceil_log2(size); + let number_of_limbs = (num_vars + C - 1) / C; + let cell_modulo = F::from(1 << C); + + let mut res = vec![vec![F::ZERO; number_of_limbs]]; + + for i in 1..size { + res.push(add_one_to_big_num(cell_modulo, &res[i - 1])); + } + + res + } +} + +/// Construct `UInt` from `Vec` +impl TryFrom> for UInt { + type Error = UtilError; + + fn try_from(limbs: Vec) -> Result { + if limbs.len() != Self::NUM_CELLS { + return Err(UtilError::UIntError(format!( + "cannot construct UInt<{}, {}> from {} cells, requires {} cells", + M, + C, + limbs.len(), + Self::NUM_CELLS + ))); + } + + Ok(Self { + limbs: UintLimb::WitIn(limbs), + carries: None, + }) + } +} + +/// Construct `UInt` from `$[CellId]` +impl TryFrom<&[WitIn]> for UInt { + type Error = UtilError; + + fn try_from(values: &[WitIn]) -> Result { + values.to_vec().try_into() + } +} + +// #[cfg(test)] +// mod tests { +// use crate::uint::uint::UInt; +// use gkr::structs::{Circuit, CircuitWitness}; +// use goldilocks::{Goldilocks, GoldilocksExt2}; +// use itertools::Itertools; +// use simple_frontend::structs::CircuitBuilder; + +// #[test] +// fn test_uint_from_cell_ids() { +// // 33 total bits and each cells holds just 4 bits +// // to hold all 33 bits without truncations, we'd need 9 cells +// // 9 * 4 = 36 > 33 +// type UInt33 = UInt<33, 4>; +// assert!(UInt33::try_from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).is_ok()); +// assert!(UInt33::try_from(vec![1, 2, 3]).is_err()); +// assert!(UInt33::try_from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).is_err()); +// } + +// #[test] +// fn test_uint_from_different_sized_cell_values() { +// // build circuit +// let mut circuit_builder = CircuitBuilder::::new(); +// let (_, small_values) = circuit_builder.create_witness_in(8); +// type UInt30 = UInt<30, 6>; +// let uint_instance = +// UInt30::from_different_sized_cell_values(&mut circuit_builder, &small_values, 2, true) +// .unwrap(); +// circuit_builder.configure(); +// let circuit = Circuit::new(&circuit_builder); + +// // input +// // we start with cells of bit width 2 (8 of them) +// // 11 00 10 11 01 10 01 01 (bit representation) +// // 3 0 2 3 1 2 1 1 (field representation) +// // +// // repacking into cells of bit width 6 +// // 110010 110110 010100 +// // since total bit = 30 then expect 5 cells ( 30 / 6) +// // since we have 3 cells, we need to pad with 2 more +// // hence expected output: +// // 100011 100111 000101 000000 000000(bit representation) +// // 35 39 5 0 0 + +// let witness_values = vec![3, 0, 2, 3, 1, 2, 1, 1] +// .into_iter() +// .map(|v| Goldilocks::from(v)) +// .collect_vec(); +// let circuit_witness = { +// let challenges = vec![GoldilocksExt2::from(2)]; +// let mut circuit_witness = CircuitWitness::new(&circuit, challenges); +// circuit_witness.add_instance(&circuit, vec![witness_values]); +// circuit_witness +// }; +// circuit_witness.check_correctness(&circuit); + +// let output = circuit_witness.output_layer_witness_ref().instances[0].to_vec(); +// assert_eq!( +// &output[..5], +// vec![35, 39, 5, 0, 0] +// .into_iter() +// .map(|v| Goldilocks::from(v)) +// .collect_vec() +// ); + +// // padding to power of 2 +// assert_eq!( +// &output[5..], +// vec![0, 0, 0] +// .into_iter() +// .map(|v| Goldilocks::from(v)) +// .collect_vec() +// ); +// } + +// #[test] +// fn test_counter_vector() { +// // each limb has 5 bits so all number from 0..3 should require only 1 limb +// type UInt30 = UInt<30, 5>; +// let res = UInt30::counter_vector::(3); +// assert_eq!( +// res, +// vec![ +// vec![Goldilocks::from(0)], +// vec![Goldilocks::from(1)], +// vec![Goldilocks::from(2)] +// ] +// ); + +// // each limb has a single bit, number from 0..5 should require 3 limbs each +// type UInt50 = UInt<50, 1>; +// let res = UInt50::counter_vector::(5); +// assert_eq!( +// res, +// vec![ +// // 0 +// vec![ +// Goldilocks::from(0), +// Goldilocks::from(0), +// Goldilocks::from(0) +// ], +// // 1 +// vec![ +// Goldilocks::from(1), +// Goldilocks::from(0), +// Goldilocks::from(0) +// ], +// // 2 +// vec![ +// Goldilocks::from(0), +// Goldilocks::from(1), +// Goldilocks::from(0) +// ], +// // 3 +// vec![ +// Goldilocks::from(1), +// Goldilocks::from(1), +// Goldilocks::from(0) +// ], +// // 4 +// vec![ +// Goldilocks::from(0), +// Goldilocks::from(0), +// Goldilocks::from(1) +// ], +// ] +// ); +// } +// } diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index af6f637a4..afcaccaab 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -1,45 +1,702 @@ use ff_ext::ExtensionField; -use itertools::izip; +use goldilocks::SmallField; +use itertools::{izip, Itertools}; -use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, expression::Expression}; +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr}, +}; -use super::UInt; +use super::{UInt, UintLimb}; -impl UInt { - pub fn add_const( +impl UInt { + const POW_OF_C: usize = 2_usize.pow(C as u32); + const LIMB_BIT_MASK: u64 = (1 << C) - 1; + + fn internal_add( + &self, + circuit_builder: &mut CircuitBuilder, + addend1: &Vec>, + addend2: &Vec>, + check_overflow: bool, + ) -> Result, ZKVMError> { + let mut c = UInt::::new_limb_as_expr(); + + // allocate witness cells and do range checks for carries + c.create_carry_witin(circuit_builder); + + // perform add operation + // c[i] = a[i] + b[i] + carry[i-1] - carry[i] * 2 ^ C + c.limbs = UintLimb::Expression( + (*addend1) + .iter() + .zip((*addend2).iter()) + .enumerate() + .map(|(i, (a, b))| { + let carries = c.carries.as_ref().unwrap(); + let carry = carries[i].expr() * Self::POW_OF_C.into(); + if i > 0 { + a.clone() + b.clone() + carries[i - 1].expr() - carry + } else { + a.clone() + b.clone() - carry + } + }) + .collect_vec(), + ); + + // overflow check + if check_overflow { + circuit_builder.require_zero(c.carries.as_ref().unwrap().last().unwrap().expr())?; + } + + Ok(c) + } + + pub fn add_const( &self, - _circuit_builder: &CircuitBuilder, - _constant: Expression, + circuit_builder: &mut CircuitBuilder, + constant: Expression, ) -> Result { - // TODO - Ok(self.clone()) + let Expression::Constant(c) = constant else { + panic!("addend is not a constant type"); + }; + let b = c.to_canonical_u64(); + + // convert Expression::Constant to limbs + let b_limbs = (0..Self::NUM_CELLS) + .map(|i| Expression::Constant(E::BaseField::from((b >> (C * i)) & Self::LIMB_BIT_MASK))) + .collect_vec(); + + self.internal_add(circuit_builder, &self.expr(), &b_limbs, true) } /// Little-endian addition. - pub fn add( + pub fn add( &self, circuit_builder: &mut CircuitBuilder, - addend_1: &UInt, - ) -> Result, ZKVMError> { - // TODO - Ok(self.clone()) + addend: &UInt, + ) -> Result, ZKVMError> { + self.internal_add(circuit_builder, &self.expr(), &addend.expr(), true) } - /// Little-endian addition. - pub fn eq( + /// Little-endian addition without overflow check + pub fn add_unsafe( + &self, + circuit_builder: &mut CircuitBuilder, + addend: &UInt, + ) -> Result, ZKVMError> { + self.internal_add(circuit_builder, &self.expr(), &addend.expr(), false) + } + + fn internal_mul( + &mut self, + circuit_builder: &mut CircuitBuilder, + multiplier: &mut UInt, + check_overflow: bool, + ) -> Result, ZKVMError> { + let mut c = UInt::::new(circuit_builder); + // allocate witness cells and do range checks for carries + c.create_carry_witin(circuit_builder); + + // We only allow expressions are in monomial form + // if any of a or b is in Expression term, it would cause error. + // So a small trick here, creating a witness and constrain the witness and the expression is equal + let mut create_expr = |u: &mut UInt| { + if u.is_expr() { + let existing_expr = u.expr(); + // this will overwrite existing expressions + u.replace_limbs_with_witin(circuit_builder); + // check if the new witness equals the existing expression + izip!(u.expr(), existing_expr) + .try_for_each(|(lhs, rhs)| circuit_builder.require_equal(lhs, rhs)) + .unwrap(); + } + u.expr() + }; + + let a_expr = create_expr(self); + let b_expr = create_expr(multiplier); + + // result check + let c_expr = c.expr(); + let c_carries = c.carries.as_ref().unwrap(); + + // a_expr[0] * b_expr[0] - c_carry[0] * 2^C = c_expr[0] + circuit_builder.require_equal( + a_expr[0].clone() * b_expr[0].clone() - c_carries[0].expr() * Self::POW_OF_C.into(), + c_expr[0].clone(), + )?; + // a_expr[0] * b_expr[1] + a_expr[1] * b_expr[0] - c_carry[1] * 2^C + c_carry[0] = c_expr[1] + circuit_builder.require_equal( + a_expr[0].clone() * b_expr[0].clone() - c_carries[1].expr() * Self::POW_OF_C.into() + + c_carries[0].expr(), + c_expr[1].clone(), + )?; + // a_expr[0] * b_expr[2] + a_expr[1] * b_expr[1] + a_expr[2] * b_expr[0] - + // c_carry[2] * 2^C + c_carry[1] = c_expr[2] + circuit_builder.require_equal( + a_expr[0].clone() * b_expr[2].clone() + + a_expr[1].clone() * b_expr[1].clone() + + a_expr[2].clone() * b_expr[0].clone() + - c_carries[2].expr() * Self::POW_OF_C.into() + + c_carries[1].expr(), + c_expr[2].clone(), + )?; + // a_expr[0] * b_expr[3] + a_expr[1] * b_expr[2] + a_expr[2] * b_expr[1] + + // a_expr[3] * b_expr[0] - c_carry[3] * 2^C + c_carry[2] = c_expr[3] + circuit_builder.require_equal( + a_expr[0].clone() * b_expr[3].clone() + + a_expr[1].clone() * b_expr[2].clone() + + a_expr[2].clone() * b_expr[1].clone() + + a_expr[3].clone() * b_expr[0].clone() + - c_carries[3].expr() * Self::POW_OF_C.into() + + c_carries[2].expr(), + c_expr[3].clone(), + )?; + + // overflow check + if check_overflow { + circuit_builder.require_zero(c.carries.as_ref().unwrap().last().unwrap().expr())?; + } + + Ok(c) + } + + pub fn mul( + &mut self, + circuit_builder: &mut CircuitBuilder, + multiplier: &mut UInt, + ) -> Result, ZKVMError> { + self.internal_mul(circuit_builder, multiplier, true) + } + + pub fn mul_unsafe( + &mut self, + circuit_builder: &mut CircuitBuilder, + multiplier: &mut UInt, + ) -> Result, ZKVMError> { + self.internal_mul(circuit_builder, multiplier, true) + } + + /// Check two UInt are equal + pub fn eq( &self, circuit_builder: &mut CircuitBuilder, - rhs: &UInt, + rhs: &UInt, ) -> Result<(), ZKVMError> { izip!(self.expr(), rhs.expr()) .try_for_each(|(lhs, rhs)| circuit_builder.require_equal(lhs, rhs)) } - pub fn lt( + pub fn lt( &self, - circuit_builder: &mut CircuitBuilder, - rhs: &UInt, + _circuit_builder: &mut CircuitBuilder, + _rhs: &UInt, ) -> Result, ZKVMError> { Ok(self.expr().remove(0) + 1.into()) } } + +#[cfg(test)] +mod tests { + + mod add { + use crate::{ + circuit_builder::CircuitBuilder, expression::Expression, scheme::utils::eval_by_expr, + uint::UInt, + }; + use ff::Field; + use goldilocks::GoldilocksExt2; + use itertools::Itertools; + + type E = GoldilocksExt2; + #[test] + fn test_add_no_carries() { + let mut circuit_builder = CircuitBuilder::::new(); + + // a = 1 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // c = 3 + 2 * 2^16 with 0 carries + let a = vec![1, 1, 0, 0]; + let b = vec![2, 1, 0, 0]; + let carries = vec![0; 4]; + let witness_values = [a, b, carries] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let a = UInt::<64, 16, E>::new(&mut circuit_builder); + let b = UInt::<64, 16, E>::new(&mut circuit_builder); + let c = a.add(&mut circuit_builder, &b).unwrap(); + + // verify limb_c[] = limb_a[] + limb_b[] + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[0]), + E::from(3) + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[1]), + E::from(2) + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[2]), + E::ZERO + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[3]), + E::ZERO + ); + } + + #[test] + fn test_add_w_carry() { + type E = GoldilocksExt2; + let mut circuit_builder = CircuitBuilder::::new(); + + // a = 65535 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // c = 1 + 3 * 2^16 with carries [1, 0, 0, 0] + let a = vec![0xFFFF, 1, 0, 0]; + let b = vec![2, 1, 0, 0]; + let carries = vec![1, 0, 0, 0]; + let witness_values = [a, b, carries] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let a = UInt::<64, 16, E>::new(&mut circuit_builder); + let b = UInt::<64, 16, E>::new(&mut circuit_builder); + let c = a.add(&mut circuit_builder, &b).unwrap(); + + // verify limb_c[] = limb_a[] + limb_b[] + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[0]), + E::ONE + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[1]), + E::from(3) + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[2]), + E::ZERO + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[3]), + E::ZERO + ); + } + + #[test] + fn test_add_w_carries() { + let mut circuit_builder = CircuitBuilder::::new(); + + // a = 65535 + 65534 * 2^16 + // b = 2 + 1 * 2^16 + // c = 1 + 0 * 2^16 + 1 * 2^32 with carries [1, 1, 0, 0] + let a = vec![0xFFFF, 0xFFFE, 0, 0]; + let b = vec![2, 1, 0, 0]; + let carries = vec![1, 1, 0, 0]; + let witness_values = [a, b, carries] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let a = UInt::<64, 16, E>::new(&mut circuit_builder); + let b = UInt::<64, 16, E>::new(&mut circuit_builder); + let c = a.add(&mut circuit_builder, &b).unwrap(); + + // verify limb_c[] = limb_a[] + limb_b[] + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[0]), + E::ONE + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[1]), + E::ZERO + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[2]), + E::ONE + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[3]), + E::ZERO + ); + } + + #[test] + fn test_add_w_overflow() { + let mut circuit_builder = CircuitBuilder::::new(); + + // a = 1 + 1 * 2^16 + 0 + 65535 * 2^48 + // b = 2 + 1 * 2^16 + 0 + 2 * 2^48 + // c = 3 + 2 * 2^16 + 0 + 1 * 2^48 with carries [0, 0, 0, 1] + let a = vec![1, 1, 0, 0xFFFF]; + let b = vec![2, 1, 0, 2]; + let carries = vec![0, 0, 0, 1]; + let witness_values = [a, b, carries] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let a = UInt::<64, 16, E>::new(&mut circuit_builder); + let b = UInt::<64, 16, E>::new(&mut circuit_builder); + let c = a.add(&mut circuit_builder, &b).unwrap(); + + // verify limb_c[] = limb_a[] + limb_b[] + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[0]), + E::from(3) + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[1]), + E::from(2) + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[2]), + E::ZERO + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[3]), + E::ONE + ); + } + + #[test] + fn test_add_const_no_carries() { + let mut circuit_builder = CircuitBuilder::::new(); + + // a = 1 + 1 * 2^16 + // const b = 2 + // c = 3 + 1 * 2^16 with 0 carries + let a = vec![1, 1, 0, 0]; + let carries = vec![0; 4]; + let witness_values = [a, carries] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let a = UInt::<64, 16, E>::new(&mut circuit_builder); + let b = Expression::Constant(2.into()); + let c = a.add_const(&mut circuit_builder, b).unwrap(); + + // verify limb_c[] = limb_a[] + limb_b[] + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[0]), + E::from(3) + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[1]), + E::ONE + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[2]), + E::ZERO + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[3]), + E::ZERO + ); + } + + #[test] + fn test_add_const_w_carries() { + let mut circuit_builder = CircuitBuilder::::new(); + + // a = 65535 + 65534 * 2^16 + // b = 2 + 1 * 2^16 + // c = 1 + 0 * 2^16 + 1 * 2^32 with carries [1, 1, 0, 0] + let a = vec![0xFFFF, 0xFFFE, 0, 0]; + let carries = vec![1, 1, 0, 0]; + let witness_values = [a, carries] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let a = UInt::<64, 16, E>::new(&mut circuit_builder); + let b = Expression::Constant(65538.into()); + let c = a.add_const(&mut circuit_builder, b).unwrap(); + + // verify limb_c[] = limb_a[] + limb_b[] + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[0]), + E::ONE + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[1]), + E::ZERO + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[2]), + E::ONE + ); + assert_eq!( + eval_by_expr(&witness_values, &challenges, &c.expr()[3]), + E::ZERO + ); + } + } + + mod mul { + use crate::{circuit_builder::CircuitBuilder, scheme::utils::eval_by_expr, uint::UInt}; + use ff_ext::ExtensionField; + use goldilocks::GoldilocksExt2; + use itertools::Itertools; + + type E = GoldilocksExt2; // 18446744069414584321 + const POW_OF_C: u64 = 2_usize.pow(16u32) as u64; + #[test] + fn test_mul_no_carries() { + // a = 1 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // c = 2 + 3 * 2^16 + 1 * 2^32 = 4,295,163,906 + let wit_a = vec![1, 1, 0, 0]; + let wit_b = vec![2, 1, 0, 0]; + let wit_c = vec![2, 3, 1, 0]; + let wit_carries = vec![0, 0, 0, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::(witness_values); + } + + #[test] + fn test_mul_w_carry() { + // a = 256 + 1 * 2^16 + // b = 257 + 1 * 2^16 + // c = 256 + 514 * 2^16 + 1 * 2^32 = 4,328,653,056 + let wit_a = vec![256, 1, 0, 0]; + let wit_b = vec![257, 1, 0, 0]; + let wit_c = vec![256, 514, 1, 0]; + let wit_carries = vec![1, 0, 0, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::(witness_values); + } + + #[test] + fn test_mul_w_carries() { + // a = 256 + 256 * 2^16 = 16,777,472 + // b = 257 + 256 * 2^16 = 16,777,473 + // c = 256 + 257 * 2^16 + 2 * 2^32 + 1 * 2^48 = 281,483,583,488,256 + let wit_a = vec![256, 256, 0, 0]; + let wit_b = vec![257, 256, 0, 0]; + // result = [256 * 257, 256*256 + 256*257, 256*256, 0] + // ==> [256 + 1 * (2^16), 256 + 2 * (2^16), 0 + 1 * (2^16), 0] + // so we get wit_c = [256, 256, 0, 0] and carries = [1, 2, 1, 0] + let wit_c = vec![256, 257, 2, 1]; + let wit_carries = vec![1, 2, 1, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::(witness_values); + } + + fn verify(witness_values: Vec) { + let mut circuit_builder = CircuitBuilder::::new(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let mut uint_a = UInt::<64, 16, E>::new(&mut circuit_builder); + let mut uint_b = UInt::<64, 16, E>::new(&mut circuit_builder); + let uint_c = uint_a.mul(&mut circuit_builder, &mut uint_b).unwrap(); + + let a = &witness_values[0..4]; + let b = &witness_values[4..8]; + let c_carries = &witness_values[12..16]; + + // limbs cal. + let t0 = a[0] * b[0] - c_carries[0] * POW_OF_C; + let t1 = a[0] * b[1] + a[1] * b[0] - c_carries[1] * POW_OF_C + c_carries[0]; + let t2 = + a[0] * b[2] + a[1] * b[1] + a[2] * b[0] - c_carries[2] * POW_OF_C + c_carries[1]; + let t3 = a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0] + - c_carries[3] * POW_OF_C + + c_carries[2]; + + // verify + let c_expr = uint_c.expr(); + let w: Vec = witness_values.iter().map(|&a| a.into()).collect_vec(); + assert_eq!(eval_by_expr(&w, &challenges, &c_expr[0]), E::from(t0)); + assert_eq!(eval_by_expr(&w, &challenges, &c_expr[1]), E::from(t1)); + assert_eq!(eval_by_expr(&w, &challenges, &c_expr[2]), E::from(t2)); + assert_eq!(eval_by_expr(&w, &challenges, &c_expr[3]), E::from(t3)); + } + } + + mod mul_add { + use crate::{circuit_builder::CircuitBuilder, scheme::utils::eval_by_expr, uint::UInt}; + use goldilocks::GoldilocksExt2; + use itertools::Itertools; + + type E = GoldilocksExt2; // 18446744069414584321 + #[test] + fn test_add_mul() { + // c = a + b + // e = c * d + + // a = 1 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // ==> c = 3 + 2 * 2^16 with 0 carries + // d = 1 + 1 * 2^16 + // ==> e = 3 + 5 * 2^16 + 2 * 2^32 = 8,590,262,275 + let a = vec![1, 1, 0, 0]; + let b = vec![2, 1, 0, 0]; + let c_carries = vec![0; 4]; + // witness of e = c * d + let new_c = vec![3, 2, 0, 0]; + let new_c_carries = c_carries.clone(); + let d = vec![1, 1, 0, 0]; + let e = vec![3, 5, 2, 0]; + let e_carries = vec![0; 4]; + + let witness_values: Vec = [ + a, + b, + c_carries.clone(), + // e = c * d + d, + e.clone(), + e_carries.clone(), + new_c, + new_c_carries, + ] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + + let mut circuit_builder = CircuitBuilder::::new(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let uint_a = UInt::<64, 16, E>::new(&mut circuit_builder); + let mut uint_b = UInt::<64, 16, E>::new(&mut circuit_builder); + let mut uint_c = uint_a.add(&mut circuit_builder, &mut uint_b).unwrap(); + let mut uint_d = UInt::<64, 16, E>::new(&mut circuit_builder); + let uint_e = uint_c.mul(&mut circuit_builder, &mut uint_d).unwrap(); + + uint_e.expr().iter().enumerate().for_each(|(i, ret)| { + // limbs check + assert_eq!( + eval_by_expr(&witness_values, &challenges, ret), + E::from(e.clone()[i]) + ); + }); + } + + #[test] + fn test_add_mul2() { + // c = a + b + // f = d + e + // g = c * f + + // a = 1 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // ==> c = 3 + 2 * 2^16 with 0 carries + // d = 1 + 1 * 2^16 + // e = 2 + 1 * 2^16 + // ==> f = 3 + 2 * 2^16 with 0 carries + // ==> e = 9 + 12 * 2^16 + 4 * 2^32 = 17,180,655,625 + let a = vec![1, 1, 0, 0]; + let b = vec![2, 1, 0, 0]; + let c_carries = vec![0; 4]; + // witness of g = c * f + let new_c = vec![3, 2, 0, 0]; + let new_c_carries = c_carries.clone(); + let g = vec![9, 12, 4, 0]; + let g_carries = vec![0; 4]; + + let witness_values: Vec = [ + // c = a + b + a.clone(), + b.clone(), + c_carries.clone(), + // f = d + e + a, + b, + c_carries.clone(), + // g = c * f + g.clone(), + g_carries, + new_c.clone(), + new_c_carries.clone(), + new_c, + new_c_carries, + ] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + + let mut circuit_builder = CircuitBuilder::::new(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let uint_a = UInt::<64, 16, E>::new(&mut circuit_builder); + let mut uint_b = UInt::<64, 16, E>::new(&mut circuit_builder); + let mut uint_c = uint_a.add(&mut circuit_builder, &mut uint_b).unwrap(); + let uint_d = UInt::<64, 16, E>::new(&mut circuit_builder); + let mut uint_e = UInt::<64, 16, E>::new(&mut circuit_builder); + let mut uint_f = uint_d.add(&mut circuit_builder, &mut uint_e).unwrap(); + let uint_g = uint_c.mul(&mut circuit_builder, &mut uint_f).unwrap(); + + uint_g.expr().iter().enumerate().for_each(|(i, ret)| { + // limbs check + assert_eq!( + eval_by_expr(&witness_values, &challenges, ret), + E::from(g.clone()[i]) + ); + }); + } + + #[test] + fn test_mul_add() { + // c = a * b + // e = c + d + + // a = 1 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // ==> c = 2 + 3 * 2^16 + 1 * 2^32 + // d = 1 + 1 * 2^16 + // ==> e = 3 + 4 * 2^16 + 1 * 2^32 + let a = vec![1, 1, 0, 0]; + let b = vec![2, 1, 0, 0]; + let c = vec![2, 3, 1, 0]; + let c_carries = vec![0; 4]; + // e = c + d + let d = vec![1, 1, 0, 0]; + let e = vec![3, 4, 1, 0]; + let e_carries = vec![0; 4]; + + let witness_values: Vec = [a, b, c, c_carries, d, e_carries] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + + let mut circuit_builder = CircuitBuilder::::new(); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let mut uint_a = UInt::<64, 16, E>::new(&mut circuit_builder); + let mut uint_b = UInt::<64, 16, E>::new(&mut circuit_builder); + let uint_c = uint_a.mul(&mut circuit_builder, &mut uint_b).unwrap(); + let mut uint_d = UInt::<64, 16, E>::new(&mut circuit_builder); + let uint_e = uint_c.add(&mut circuit_builder, &mut uint_d).unwrap(); + + uint_e.expr().iter().enumerate().for_each(|(i, ret)| { + // limbs check + assert_eq!( + eval_by_expr(&witness_values, &challenges, ret), + E::from(e.clone()[i]) + ); + }); + } + } +} diff --git a/ceno_zkvm/src/uint/constants.rs b/ceno_zkvm/src/uint/constants.rs index 271bbc8aa..94bdbc93e 100644 --- a/ceno_zkvm/src/uint/constants.rs +++ b/ceno_zkvm/src/uint/constants.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use crate::utils::const_min; use super::UInt; @@ -7,7 +5,9 @@ use super::UInt; pub const RANGE_CHIP_BIT_WIDTH: usize = 16; pub const BYTE_BIT_WIDTH: usize = 8; -impl UInt { +use ff_ext::ExtensionField; + +impl UInt { pub const M: usize = M; pub const C: usize = C; @@ -18,56 +18,13 @@ impl UInt { /// but if M >= C then maximum_usable_cell_capacity = C pub const MAX_CELL_BIT_WIDTH: usize = const_min(M, C); - /// `N_OPERAND_CELLS` represent the minimum number of cells each of size `C` needed + /// `NUM_CELLS` represent the minimum number of cells each of size `C` needed /// to hold `M` total bits - pub const N_OPERAND_CELLS: usize = (M + C - 1) / C; + pub const NUM_CELLS: usize = (M + C - 1) / C; /// The number of `RANGE_CHIP_BIT_WIDTH` cells needed to represent one cell of size `C` const N_RANGE_CELLS_PER_CELL: usize = (C + RANGE_CHIP_BIT_WIDTH - 1) / RANGE_CHIP_BIT_WIDTH; /// The number of `RANGE_CHIP_BIT_WIDTH` cells needed to represent the entire `UInt` - pub const N_RANGE_CELLS: usize = Self::N_OPERAND_CELLS * Self::N_RANGE_CELLS_PER_CELL; -} - -/// Holds addition specific constants -pub struct AddSubConstants { - _marker: PhantomData, -} - -impl AddSubConstants> { - /// Number of cells required to track carry information for the addition operation. - /// operand_0 = a b c - /// operand_1 = e f g - /// ---------- - /// result = h i j - /// carry = k l m - - /// |Carry| = |Cells| - pub const N_CARRY_CELLS: usize = UInt::::N_OPERAND_CELLS; - - /// Number of cells required to track carry information if we assume the addition - /// operation cannot lead to overflow. - /// operand_0 = a b c - /// operand_1 = e f g - /// ---------- - /// result = h i j - /// carry = l m - - /// |Carry| = |Cells - 1| - const N_CARRY_CELLS_NO_OVERFLOW: usize = Self::N_CARRY_CELLS - 1; - - /// The size of the witness - pub const N_WITNESS_CELLS: usize = UInt::::N_RANGE_CELLS + Self::N_CARRY_CELLS; - - /// The size of the witness assuming carry has no overflow - /// |Range_values| + |Carry - 1| - pub const N_WITNESS_CELLS_NO_CARRY_OVERFLOW: usize = - UInt::::N_RANGE_CELLS + Self::N_CARRY_CELLS_NO_OVERFLOW; - - pub const N_NO_OVERFLOW_WITNESS_UNSAFE_CELLS: usize = Self::N_CARRY_CELLS_NO_OVERFLOW; - - /// The number of `RANGE_CHIP_BIT_WIDTH` cells needed to represent the carry cells, assuming - /// no overflow. - // TODO: if guaranteed no overflow, then we don't need to range check the highest limb - // hence this can be (N_OPERANDS - 1) * N_RANGE_CELLS_PER_CELL - // update this once, range check logic doesn't assume all limbs - pub const N_RANGE_CELLS_NO_OVERFLOW: usize = UInt::::N_RANGE_CELLS; + pub const N_RANGE_CELLS: usize = Self::NUM_CELLS * Self::N_RANGE_CELLS_PER_CELL; } diff --git a/ceno_zkvm/src/uint/uint.rs b/ceno_zkvm/src/uint/uint.rs deleted file mode 100644 index a14e1f496..000000000 --- a/ceno_zkvm/src/uint/uint.rs +++ /dev/null @@ -1,278 +0,0 @@ -use crate::{ - circuit_builder::CircuitBuilder, - error::UtilError, - expression::{Expression, ToExpr, WitIn}, - utils::add_one_to_big_num, -}; -use ff_ext::ExtensionField; -use goldilocks::SmallField; -use sumcheck::util::ceil_log2; - -use super::constants::BYTE_BIT_WIDTH; - -#[derive(Clone)] -/// Unsigned integer with `M` total bits. `C` denotes the cell bit width. -/// Represented in little endian form. -pub struct UInt { - pub values: Vec, -} - -impl UInt { - pub fn new(circuit_builder: &mut CircuitBuilder) -> Self { - Self { - values: (0..Self::N_OPERAND_CELLS) - .map(|_| circuit_builder.create_witin()) - .collect(), - } - } - - pub fn expr(&self) -> Vec> { - self.values - .iter() - .map(ToExpr::expr) - .collect::>>() - } - - /// Return the `UInt` underlying cell id's - pub fn wits_in(&self) -> &[WitIn] { - &self.values - } - - /// Builds a `UInt` instance from a set of cells that represent `RANGE_VALUES` - /// assumes range_values are represented in little endian form - pub fn from_range_wits_in( - circuit_builder: &mut CircuitBuilder, - range_values: &[WitIn], - ) -> Result { - // Self::from_different_sized_cell_values( - // circuit_builder, - // range_values, - // RANGE_CHIP_BIT_WIDTH, - // true, - // ) - todo!() - } - - /// Builds a `UInt` instance from a set of cells that represent big-endian `BYTE_VALUES` - pub fn from_bytes_big_endian( - circuit_builder: &mut CircuitBuilder, - bytes: &[WitIn], - ) -> Result { - Self::from_bytes(circuit_builder, bytes, false) - } - - /// Builds a `UInt` instance from a set of cells that represent little-endian `BYTE_VALUES` - pub fn from_bytes_little_endian( - circuit_builder: &mut CircuitBuilder, - bytes: &[WitIn], - ) -> Result { - Self::from_bytes(circuit_builder, bytes, true) - } - - /// Builds a `UInt` instance from a set of cells that represent `BYTE_VALUES` - pub fn from_bytes( - circuit_builder: &mut CircuitBuilder, - bytes: &[WitIn], - is_little_endian: bool, - ) -> Result { - Self::from_different_sized_cell_values( - circuit_builder, - bytes, - BYTE_BIT_WIDTH, - is_little_endian, - ) - } - - /// Builds a `UInt` instance from a set of cell values of a certain `CELL_WIDTH` - fn from_different_sized_cell_values( - circuit_builder: &mut CircuitBuilder, - wits_in: &[WitIn], - cell_width: usize, - is_little_endian: bool, - ) -> Result { - todo!() - // let mut values = convert_decomp( - // circuit_builder, - // wits_in, - // cell_width, - // Self::MAX_CELL_BIT_WIDTH, - // is_little_endian, - // )?; - // debug_assert!(values.len() <= Self::N_OPERAND_CELLS); - // pad_cells(circuit_builder, &mut values, Self::N_OPERAND_CELLS); - // values.try_into() - } - - /// Generate ((0)_{2^C}, (1)_{2^C}, ..., (size - 1)_{2^C}) - pub fn counter_vector(size: usize) -> Vec> { - let num_vars = ceil_log2(size); - let number_of_limbs = (num_vars + C - 1) / C; - let cell_modulo = F::from(1 << C); - - let mut res = vec![vec![F::ZERO; number_of_limbs]]; - - for i in 1..size { - res.push(add_one_to_big_num(cell_modulo, &res[i - 1])); - } - - res - } -} - -/// Construct `UInt` from `Vec` -impl TryFrom> for UInt { - type Error = UtilError; - - fn try_from(values: Vec) -> Result { - if values.len() != Self::N_OPERAND_CELLS { - return Err(UtilError::UIntError(format!( - "cannot construct UInt<{}, {}> from {} cells, requires {} cells", - M, - C, - values.len(), - Self::N_OPERAND_CELLS - ))); - } - - Ok(Self { values }) - } -} - -/// Construct `UInt` from `$[CellId]` -impl TryFrom<&[WitIn]> for UInt { - type Error = UtilError; - - fn try_from(values: &[WitIn]) -> Result { - values.to_vec().try_into() - } -} - -// #[cfg(test)] -// mod tests { -// use crate::uint::uint::UInt; -// use gkr::structs::{Circuit, CircuitWitness}; -// use goldilocks::{Goldilocks, GoldilocksExt2}; -// use itertools::Itertools; -// use simple_frontend::structs::CircuitBuilder; - -// #[test] -// fn test_uint_from_cell_ids() { -// // 33 total bits and each cells holds just 4 bits -// // to hold all 33 bits without truncations, we'd need 9 cells -// // 9 * 4 = 36 > 33 -// type UInt33 = UInt<33, 4>; -// assert!(UInt33::try_from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).is_ok()); -// assert!(UInt33::try_from(vec![1, 2, 3]).is_err()); -// assert!(UInt33::try_from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).is_err()); -// } - -// #[test] -// fn test_uint_from_different_sized_cell_values() { -// // build circuit -// let mut circuit_builder = CircuitBuilder::::new(); -// let (_, small_values) = circuit_builder.create_witness_in(8); -// type UInt30 = UInt<30, 6>; -// let uint_instance = -// UInt30::from_different_sized_cell_values(&mut circuit_builder, &small_values, 2, true) -// .unwrap(); -// circuit_builder.configure(); -// let circuit = Circuit::new(&circuit_builder); - -// // input -// // we start with cells of bit width 2 (8 of them) -// // 11 00 10 11 01 10 01 01 (bit representation) -// // 3 0 2 3 1 2 1 1 (field representation) -// // -// // repacking into cells of bit width 6 -// // 110010 110110 010100 -// // since total bit = 30 then expect 5 cells ( 30 / 6) -// // since we have 3 cells, we need to pad with 2 more -// // hence expected output: -// // 100011 100111 000101 000000 000000(bit representation) -// // 35 39 5 0 0 - -// let witness_values = vec![3, 0, 2, 3, 1, 2, 1, 1] -// .into_iter() -// .map(|v| Goldilocks::from(v)) -// .collect_vec(); -// let circuit_witness = { -// let challenges = vec![GoldilocksExt2::from(2)]; -// let mut circuit_witness = CircuitWitness::new(&circuit, challenges); -// circuit_witness.add_instance(&circuit, vec![witness_values]); -// circuit_witness -// }; -// circuit_witness.check_correctness(&circuit); - -// let output = circuit_witness.output_layer_witness_ref().instances[0].to_vec(); -// assert_eq!( -// &output[..5], -// vec![35, 39, 5, 0, 0] -// .into_iter() -// .map(|v| Goldilocks::from(v)) -// .collect_vec() -// ); - -// // padding to power of 2 -// assert_eq!( -// &output[5..], -// vec![0, 0, 0] -// .into_iter() -// .map(|v| Goldilocks::from(v)) -// .collect_vec() -// ); -// } - -// #[test] -// fn test_counter_vector() { -// // each limb has 5 bits so all number from 0..3 should require only 1 limb -// type UInt30 = UInt<30, 5>; -// let res = UInt30::counter_vector::(3); -// assert_eq!( -// res, -// vec![ -// vec![Goldilocks::from(0)], -// vec![Goldilocks::from(1)], -// vec![Goldilocks::from(2)] -// ] -// ); - -// // each limb has a single bit, number from 0..5 should require 3 limbs each -// type UInt50 = UInt<50, 1>; -// let res = UInt50::counter_vector::(5); -// assert_eq!( -// res, -// vec![ -// // 0 -// vec![ -// Goldilocks::from(0), -// Goldilocks::from(0), -// Goldilocks::from(0) -// ], -// // 1 -// vec![ -// Goldilocks::from(1), -// Goldilocks::from(0), -// Goldilocks::from(0) -// ], -// // 2 -// vec![ -// Goldilocks::from(0), -// Goldilocks::from(1), -// Goldilocks::from(0) -// ], -// // 3 -// vec![ -// Goldilocks::from(1), -// Goldilocks::from(1), -// Goldilocks::from(0) -// ], -// // 4 -// vec![ -// Goldilocks::from(0), -// Goldilocks::from(0), -// Goldilocks::from(1) -// ], -// ] -// ); -// } -// } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 76c9c690e..27eae27a5 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -31,6 +31,7 @@ pub(crate) fn add_one_to_big_num(limb_modulo: F, limbs: &[F]) -> Vec(x: i64) -> E::BaseField { if x >= 0 { E::BaseField::from(x as u64) diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index 4c4b4dcac..a79bdf935 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -8,6 +8,8 @@ license.workspace = true [dependencies] tracing = "0.1.40" +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +tracing-flame = "0.2.0" ff_ext = { path = "../ff_ext" } ark-std.workspace = true ff.workspace = true @@ -15,5 +17,6 @@ goldilocks.workspace = true rayon.workspace = true serde.workspace = true + [features] parallel = [ ]