From 1ca8a43f39e735295bf1abd508084ad8e4f7fb26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Sun, 27 Oct 2024 18:53:08 +0800 Subject: [PATCH 1/3] Implement `IntoIterator` and use it --- ceno_zkvm/src/uint.rs | 24 ++++++++++++++++++++++++ ceno_zkvm/src/uint/arithmetic.rs | 18 +++++++----------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index fef0c80bc..53f0e86cd 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -34,6 +34,30 @@ pub enum UintLimb { Expression(Vec>), } +impl IntoIterator for UintLimb { + type Item = WitIn; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + match self { + UintLimb::WitIn(wits) => wits.into_iter(), + _ => unimplemented!(), + } + } +} + +impl<'a, E: ExtensionField> IntoIterator for &'a UintLimb { + type Item = &'a WitIn; + type IntoIter = std::slice::Iter<'a, WitIn>; + + fn into_iter(self) -> Self::IntoIter { + match self { + UintLimb::WitIn(wits) => wits.iter(), + _ => unimplemented!(), + } + } +} + impl UintLimb { pub fn iter(&self) -> impl Iterator { match self { diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 3ce3b65f6..02e4ffe59 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -1,6 +1,6 @@ use ff_ext::ExtensionField; use goldilocks::SmallField; -use itertools::{Itertools, izip}; +use itertools::{Itertools, enumerate, izip}; use super::{UIntLimbs, UintLimb}; use crate::{ @@ -364,16 +364,12 @@ impl UIntLimbs { .collect::, ZKVMError>>()?; // check byte diff that before the first non-zero i_0 equals zero - si.iter() - .zip(self.limbs.iter()) - .zip(rhs.limbs.iter()) - .enumerate() - .try_for_each(|(i, ((flag, a), b))| { - circuit_builder.require_zero( - || format!("byte diff {i} zero check"), - a.expr() - b.expr() - flag.expr() * a.expr() + flag.expr() * b.expr(), - ) - })?; + enumerate(izip!(&si, &self.limbs, &rhs.limbs)).try_for_each(|(i, (flag, a, b))| { + circuit_builder.require_zero( + || format!("byte diff {i} zero check"), + a.expr() - b.expr() - flag.expr() * a.expr() + flag.expr() * b.expr(), + ) + })?; // define accumulated byte sum // when a!= b, sa should equal the first non-zero byte a[i_0] From d3906138ccc6db466a9aa4a890ea3b905ed5bd89 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Tue, 5 Nov 2024 12:26:49 +0800 Subject: [PATCH 2/3] Test --- ceno_zkvm/src/uint.rs | 18 +++++++++--------- ceno_zkvm/src/uint/arithmetic.rs | 14 ++++++-------- ceno_zkvm/src/uint/logic.rs | 2 +- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 53480d1ee..ae31bdb37 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -18,7 +18,7 @@ use constants::BYTE_BIT_WIDTH; use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; -use itertools::Itertools; +use itertools::{enumerate, Itertools}; use std::{ borrow::Cow, mem::{self, MaybeUninit}, @@ -60,7 +60,7 @@ impl<'a, E: ExtensionField> IntoIterator for &'a UintLimb { } impl UintLimb { - pub fn iter(&self) -> impl Iterator { + pub fn iter_fnord(&self) -> impl Iterator { match self { UintLimb::WitIn(vec) => vec.iter(), _ => unimplemented!(), @@ -310,9 +310,9 @@ impl UIntLimbs { (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap() << 8)); shift_pows }; - let combined_limbs = x - .limbs - .iter() + let combined_limbs = (&x + .limbs) + .into_iter() .collect_vec() .chunks(k) .map(|chunk| { @@ -342,7 +342,7 @@ impl UIntLimbs { }; let split_limbs = x .limbs - .iter() + .iter_fnord() .flat_map(|large_limb| { let limbs = (0..k) .map(|_| { @@ -564,7 +564,7 @@ impl UInt { &self, cb: &mut CircuitBuilder, ) -> Result, ZKVMError> { - SignedExtendConfig::::construct_limb(cb, self.limbs.iter().last().unwrap().expr()) + SignedExtendConfig::::construct_limb(cb, (&self.limbs).into_iter().last().unwrap().expr()) } } @@ -825,8 +825,8 @@ impl<'a, T: Into + From + Copy + Default> Value<'a, T> { let mut c_limbs = vec![0u16; num_limbs]; let mut carries = vec![0u64; num_limbs]; let mut tmp = vec![0u64; num_limbs]; - a_limbs.iter().enumerate().for_each(|(i, &a_limb)| { - b_limbs.iter().enumerate().for_each(|(j, &b_limb)| { + enumerate(a_limbs).for_each(|(i, &a_limb)| { + enumerate(b_limbs).for_each(|(j, &b_limb)| { let idx = i + j; if idx < num_limbs { tmp[idx] += a_limb as u64 * b_limb as u64; diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 85eb703c8..dfe33b076 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -267,14 +267,12 @@ impl UIntLimbs { rhs: &UIntLimbs, ) -> Result { let n_limbs = Self::NUM_LIMBS; - let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = self - .limbs - .iter() - .zip_eq(rhs.limbs.iter()) - .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) - .collect::, ZKVMError>>()? - .into_iter() - .unzip(); + let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = + izip!(&self.limbs, &rhs.limbs) + .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) + .collect::, ZKVMError>>()? + .into_iter() + .unzip(); let sum_expr = is_equal_per_limb .iter() diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs index b340df982..024d09d73 100644 --- a/ceno_zkvm/src/uint/logic.rs +++ b/ceno_zkvm/src/uint/logic.rs @@ -18,7 +18,7 @@ impl UIntLimbs { b: &Self, c: &Self, ) -> Result<(), ZKVMError> { - for (a_byte, b_byte, c_byte) in izip!(a.limbs.iter(), b.limbs.iter(), c.limbs.iter()) { + for (a_byte, b_byte, c_byte) in izip!(&a.limbs, &b.limbs, &c.limbs) { cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; } Ok(()) From 4ba61c9396de9ef41a2732fe82fc6b2215cd2a8a Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Tue, 5 Nov 2024 12:33:53 +0800 Subject: [PATCH 3/3] Minimise diff --- ceno_zkvm/src/uint.rs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index ae31bdb37..fe7ff8ae0 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -18,7 +18,7 @@ use constants::BYTE_BIT_WIDTH; use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; -use itertools::{enumerate, Itertools}; +use itertools::{Itertools, enumerate}; use std::{ borrow::Cow, mem::{self, MaybeUninit}, @@ -60,11 +60,8 @@ impl<'a, E: ExtensionField> IntoIterator for &'a UintLimb { } impl UintLimb { - pub fn iter_fnord(&self) -> impl Iterator { - match self { - UintLimb::WitIn(vec) => vec.iter(), - _ => unimplemented!(), - } + pub fn iter(&self) -> impl Iterator { + self.into_iter() } } @@ -310,9 +307,9 @@ impl UIntLimbs { (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap() << 8)); shift_pows }; - let combined_limbs = (&x - .limbs) - .into_iter() + let combined_limbs = x + .limbs + .iter() .collect_vec() .chunks(k) .map(|chunk| { @@ -342,7 +339,7 @@ impl UIntLimbs { }; let split_limbs = x .limbs - .iter_fnord() + .iter() .flat_map(|large_limb| { let limbs = (0..k) .map(|_| { @@ -564,7 +561,7 @@ impl UInt { &self, cb: &mut CircuitBuilder, ) -> Result, ZKVMError> { - SignedExtendConfig::::construct_limb(cb, (&self.limbs).into_iter().last().unwrap().expr()) + SignedExtendConfig::::construct_limb(cb, self.limbs.iter().last().unwrap().expr()) } }