From 7eb11c1250133d8554bd4374d9753803537a6ed4 Mon Sep 17 00:00:00 2001 From: Nicolas Gailly Date: Wed, 17 Jan 2024 16:02:21 +0100 Subject: [PATCH] Feat/keccak trait to main (#28) * moving array access to own file * wip * adding benchmark + test * fmt * no logging on CI * adding repeated poseidon * csv output * adding keccak and benchmark * keccak added * safe init logging --- src/benches/mod.rs | 93 +++++++++++++- src/benches/recursion.rs | 270 ++++++++++++++++++++++++++++----------- src/keccak.rs | 239 ++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 4 files changed, 526 insertions(+), 77 deletions(-) create mode 100644 src/keccak.rs diff --git a/src/benches/mod.rs b/src/benches/mod.rs index f78b59057..21a8ab43e 100644 --- a/src/benches/mod.rs +++ b/src/benches/mod.rs @@ -1,6 +1,5 @@ -use std::env; - use log::{log_enabled, Level, LevelFilter}; +use std::env; use std::io::Write; mod array_access; #[cfg(test)] @@ -11,9 +10,95 @@ mod recursion; pub(crate) fn init_logging() { if !log_enabled!(Level::Debug) { env::set_var("RUST_LOG", "debug"); - env_logger::builder() + let _ = env_logger::builder() .format(|buf, record| writeln!(buf, " {}", record.args())) - .init(); + .try_init(); log::set_max_level(LevelFilter::Debug); } } + +#[cfg(test)] +mod test { + use plonky2::{ + field::extension::Extendable, + hash::hash_types::RichField, + iop::witness::PartialWitness, + plonk::{ + circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, config::GenericConfig, + }, + }; + use serde::Serialize; + use std::time; + + use crate::{circuit::UserCircuit, utils::verify_proof_tuple}; + + #[derive(Serialize, Clone, Debug)] + pub(crate) struct BenchResult { + pub circuit: String, + // n is circuit dependent + pub n: usize, + // arity is 0 when it's not recursive, 1 when ivc and more for PCD + pub arity: usize, + pub gate_count: usize, + pub building: u64, + pub proving: u64, + pub lde: usize, + pub verifying: u64, + } + + pub fn run_benchs(fname: String, benches: Vec BenchResult>>) { + let mut writer = csv::Writer::from_path(fname).unwrap(); + for bench in benches { + let result = bench(); + writer.serialize(result).unwrap(); + writer.flush().unwrap(); + } + } + + pub trait Benchable { + // returns the relevant information depending on the circuit being benchmarked + // i.e. n can be the number of times we hash some fixed length data + fn n(&self) -> usize { + 0 + } + } + + pub fn bench_simple_circuit< + F, + const D: usize, + C: GenericConfig, + U: UserCircuit + Benchable, + >( + tname: String, + u: U, + ) -> BenchResult + where + F: RichField + Extendable, + { + let mut b = CircuitBuilder::new(CircuitConfig::standard_recursion_config()); + let mut pw = PartialWitness::new(); + let now = time::Instant::now(); + let wires = U::build(&mut b); + let gate_count = b.num_gates(); + let circuit_data = b.build::(); + let building_time = now.elapsed(); + let now = time::Instant::now(); + u.prove(&mut pw, &wires); + let proof = circuit_data.prove(pw).expect("invalid proof"); + let proving_time = now.elapsed(); + let lde = circuit_data.common.lde_size(); + let now = time::Instant::now(); + verify_proof_tuple(&(proof, circuit_data.verifier_only, circuit_data.common)).unwrap(); + let verifying_time = now.elapsed(); + BenchResult { + circuit: tname, + gate_count, + n: u.n(), + arity: 0, + lde, + building: building_time.as_millis() as u64, + proving: proving_time.as_millis() as u64, + verifying: verifying_time.as_millis() as u64, + } + } +} diff --git a/src/benches/recursion.rs b/src/benches/recursion.rs index 0f52e9d95..d97703ef9 100644 --- a/src/benches/recursion.rs +++ b/src/benches/recursion.rs @@ -1,8 +1,10 @@ +use crate::benches::test::{bench_simple_circuit, run_benchs, BenchResult}; use crate::circuit::{NoopCircuit, ProofOrDummyTarget}; +use crate::keccak::{self, KeccakWires}; use itertools::Itertools; use log::info; use plonky2::field::types::Sample; -use plonky2::hash::poseidon::Poseidon; +use plonky2::gates::exponentiation::ExponentiationGate; use plonky2::{ field::extension::Extendable, hash::{ @@ -16,16 +18,16 @@ use plonky2::{ }, plonk::{ circuit_builder::CircuitBuilder, - circuit_data::CircuitConfig, config::{AlgebraicHasher, GenericConfig, GenericHashOut, PoseidonGoldilocksConfig}, proof::ProofWithPublicInputs, }, }; -use serde::Serialize; +use rand::Rng; use super::init_logging; +use super::test::Benchable; +use crate::circuit::CyclicCircuit; use crate::circuit::{PCDCircuit, Padder, UserCircuit}; -use crate::{circuit::CyclicCircuit, utils::verify_proof_tuple}; use std::{iter, time}; /// Circuit hashing ELEMS field elements into a standard Poseidon 256 bit output @@ -91,13 +93,11 @@ where pw.set_hash_target(wires.outputs, output); } } - -macro_rules! timeit { - ($a:expr) => {{ - let now = time::Instant::now(); - $a; - now.elapsed() - }}; +impl Benchable for NoopCircuit {} +impl Benchable for PoseidonCircuit { + fn n(&self) -> usize { + ELEMS + } } const D: usize = 2; @@ -108,13 +108,13 @@ type F = >::F; fn bench_recursion_noop() { #[cfg(not(ci))] init_logging(); - let tname = |i| format!("pcd_recursion_noop"); + let tname = |_| format!("pcd_recursion_noop"); macro_rules! bench_pcd { ($( $a:expr),+) => { { let mut fns : Vec BenchResult>> = vec![]; let step_fn = || NoopCircuit::new(); $( - let padder = |b: &mut CircuitBuilder| { + let padder = |_: &mut CircuitBuilder| { match $a { 1 => 12, 2 => 13, @@ -206,80 +206,204 @@ fn bench_simple_repeated_poseidon() { ); } -#[derive(Serialize, Clone, Debug)] -struct BenchResult { - circuit: String, - // n is circuit dependent - n: usize, - // arity is 0 when it's not recursive, 1 when ivc and more for PCD - arity: usize, - gate_count: usize, - building: u64, - proving: u64, - lde: usize, - verifying: u64, +fn rand_arr(size: usize) -> Vec { + (0..size) + .map(|_| rand::thread_rng().gen()) + .collect::>() } -fn run_benchs(fname: String, benches: Vec BenchResult>>) { - let mut writer = csv::Writer::from_path(fname).unwrap(); - for bench in benches { - let result = bench(); - writer.serialize(result).unwrap(); - writer.flush().unwrap(); +use crate::keccak::KeccakCircuit; +impl Benchable for KeccakCircuit { + fn n(&self) -> usize { + BYTES } } +#[derive(Clone, Debug)] +struct RepeatedKeccak { + circuits: [KeccakCircuit; N], +} -pub trait Benchable { - // returns the relevant information depending on the circuit being benchmarked - // i.e. n can be the number of times we hash some fixed length data +impl Benchable for RepeatedKeccak { fn n(&self) -> usize { - 0 + N } } -impl Benchable for NoopCircuit {} -impl Benchable for PoseidonCircuit { - fn n(&self) -> usize { - ELEMS +impl UserCircuit + for RepeatedKeccak +where + F: RichField + Extendable, +{ + type Wires = [KeccakWires; N]; + fn build(c: &mut CircuitBuilder) -> Self::Wires { + (0..N) + .map(|_| KeccakCircuit::::build(c)) + .collect::>() + .try_into() + .unwrap() + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + for (i, circuit) in self.circuits.iter().enumerate() { + circuit.prove(pw, &wires[i]); + } } } -fn bench_simple_circuit< - F, - const D: usize, - C: GenericConfig, - U: UserCircuit + Benchable, ->( - tname: String, - u: U, -) -> BenchResult + +// D = comes from plonky2 +// ARITY = of the PCD graph +// BYTES = number of bytes hashed +// N = number of times we repeat the hashing in circuit +impl + PCDCircuit for RepeatedKeccak where F: RichField + Extendable, { - let mut b = CircuitBuilder::new(CircuitConfig::standard_recursion_config()); - let mut pw = PartialWitness::new(); - let now = time::Instant::now(); - let wires = U::build(&mut b); - let gate_count = b.num_gates(); - let circuit_data = b.build::(); - let building_time = now.elapsed(); - let now = time::Instant::now(); - u.prove(&mut pw, &wires); - let proof = circuit_data.prove(pw).expect("invalid proof"); - let proving_time = now.elapsed(); - let lde = circuit_data.common.lde_size(); - let verifying_time = - timeit!( - verify_proof_tuple(&(proof, circuit_data.verifier_only, circuit_data.common)).unwrap() - ); - BenchResult { - circuit: tname, - gate_count, - n: u.n(), - arity: 0, - lde, - building: building_time.as_millis() as u64, - proving: proving_time.as_millis() as u64, - verifying: verifying_time.as_millis() as u64, + // TODO: remove assumption about public inputs, in this case we don't + // need to expose as pub inputs all the intermediate hashing + fn base_inputs(&self) -> Vec { + (0..N).flat_map(|_| F::rand_vec(8)).collect() } + fn build_recursive( + b: &mut CircuitBuilder, + p: &[ProofOrDummyTarget; ARITY], + ) -> Self::Wires { + let mut wires = vec![]; + for _ in 0..N { + wires.push(KeccakCircuit::::build_recursive(b, p)); + } + wires.try_into().unwrap() + } + fn num_io() -> usize { + let one = as PCDCircuit>::num_io(); + one * N + } +} + +/// This creates a different circuits that hash some data, different number +/// of times. This is to emulate verifying a MPT proof, where one needs +/// to consecutively hash nodes on the path from leaf to the root. +#[test] +fn bench_keccak_repeated() { + const DATA_LEN: usize = 544; + const BYTES: usize = keccak::compute_size_with_padding(DATA_LEN); + // the whole reason we need these macros is to be able to declare "like at runtime" + // an array of things with different constants inside. We can't have "const" value + // on iteration so we have to use macro to avoid hardcoding them one by one. + macro_rules! keccak_circuit { + ($( $a:expr),+) => { + { + let tname = |i| format!("repeated-keccak-n{}-b{}", i, BYTES); + let single_circuit = KeccakCircuit::::new(rand_arr(DATA_LEN)).unwrap(); + let mut fns : Vec BenchResult>> = vec![]; + $( + let name = tname($a); + let circuit = single_circuit.clone(); + fns.push(Box::new(move || { + bench_simple_circuit::( + name, + RepeatedKeccak { + circuits: [circuit; $a], + }, + ) + })); + )+ + fns + } + } + } + let fns2 = keccak_circuit!(2, 3); + run_benchs("bench_keccak_repeated.csv".to_string(), fns2); +} + +/// Launch a benchmark that runs an unified circuit that does: +/// 1. Keccak256() of some fixed data +/// 2. Verify N proofs where N varies in the experiments +/// This is to simulate the case where we want to have one unified circuit during +/// our recursion. +#[test] +fn bench_recursion_single_circuit() { + const DATA_LEN: usize = 544; + const BYTES: usize = keccak::compute_size_with_padding(DATA_LEN); + init_logging(); + let tname = |_| format!("pcd_single_circuit_keccak"); + macro_rules! bench_pcd { + ($( $a:expr),+) => { { + let mut fns : Vec BenchResult>> = vec![]; + let step_fn = || KeccakCircuit::::new(rand_arr(DATA_LEN)).unwrap(); + $( + let padder = |b: &mut CircuitBuilder| { + KeccakCircuit::<200>::build(b); + match $a { + 1 | 2 => 15, + 4 => 15, + 8 => 16, + 12 => 16, + 16 => { + b.add_gate(ExponentiationGate::new(66), vec![]); + 17 + } + _ => panic!("unrecognozed size"), + }}; + fns.push(Box::new(move || { + // arity changing but with same number of work at each step + bench_pcd_circuit::(tname($a), $a, step_fn,padder) + })); + )+ + fns + } + }; + } + let trials = bench_pcd!(1, 2); + run_benchs("bench_recursion_single_circuit.csv".to_string(), trials); +} + +/// Bench a circuit that does: +/// 1. Verify a proof recursively +/// 2. Make N consecutive hashing of fixed length +/// This is to emulate the cirucit where we can prove the update of a leaf +/// of a leaf in the merkle tree. So N must be dividable by 2, so length of +/// a proof, is N/2. +#[test] +fn bench_recursive_update_keccak() { + const DATA_LEN: usize = 544; + const BYTES: usize = keccak::compute_size_with_padding(DATA_LEN); + init_logging(); + let tname = |_| format!("pcd_single_circuit_keccak"); + let single_circuit = KeccakCircuit::::new(rand_arr(DATA_LEN)).unwrap(); + macro_rules! bench_pcd { + ($( $a:expr),+) => { { + let mut fns : Vec BenchResult>> = vec![]; + $( + let padder = |b: &mut CircuitBuilder|{ + KeccakCircuit::<200>::build(b); + match $a { + 1 => 15, + 2 => 16, + 4 => { + b.add_gate(ExponentiationGate::new(66), vec![]); + 17 + }, + 8..=14 => { + b.add_gate(ExponentiationGate::new(66), vec![]); + 18 + }, + _ => panic!("unrecognized size - fill manually"), + } + }; + fns.push(Box::new(move || { + // always 1 arity because we only verify one proof + // but verify multiple hashes + bench_pcd_circuit::(tname($a), 1, || RepeatedKeccak:: { + circuits: [single_circuit; $a] + },padder) + })); + )+ + fns + } + }; + } + let trials = bench_pcd!(1, 2); + run_benchs("bench_recursive_update_keccak.csv".to_string(), trials); } fn bench_pcd_circuit< diff --git a/src/keccak.rs b/src/keccak.rs new file mode 100644 index 000000000..c4f36baad --- /dev/null +++ b/src/keccak.rs @@ -0,0 +1,239 @@ +use anyhow::{ensure, Result}; +use plonky2::{ + field::extension::Extendable, + hash::hash_types::RichField, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::circuit_builder::CircuitBuilder, + util::ceil_div_usize, +}; +use plonky2_crypto::{ + biguint::BigUintTarget, + hash::{ + keccak256::{CircuitBuilderHashKeccak, KECCAK256_R}, + HashInputTarget, + }, + u32::arithmetic_u32::U32Target, +}; + +use crate::{ + circuit::UserCircuit, + utils::{convert_u8_to_u32, less_than, IntTargetWriter}, +}; + +/// Keccak pads data before "hashing" it. This method returns the full size +/// of the padded data before hashing. This is useful to know the actual number +/// of allocated wire one needs to reserve inside the circuit. +pub const fn compute_size_with_padding(data_len: usize) -> usize { + let input_len_bits = data_len * 8; // only pad the data that is inside the fixed buffer + let num_actual_blocks = 1 + input_len_bits / KECCAK256_R; + let padded_len_bits = num_actual_blocks * KECCAK256_R; + // reason why ^: this is annoying to do in circuit. + ceil_div_usize(padded_len_bits, 8) +} + +/// This returns only the amount of padding applied on top of the data. +pub const fn compute_padding_size(data_len: usize) -> usize { + compute_size_with_padding(data_len) - data_len +} +#[derive(Clone, Copy, Debug)] +pub struct KeccakCircuit { + data: [u8; N], + unpadded_len: usize, +} +#[derive(Clone, Debug)] +pub struct KeccakWires { + input_array: ArrayWire, + diff: Target, + // 256/u32 = 8 + output_array: [Target; 8], +} + +#[derive(Debug, Clone)] +struct ArrayWire { + arr: [Target; N], + real_len: Target, +} +impl KeccakCircuit { + pub fn new(mut data: Vec) -> Result { + let total = compute_size_with_padding(data.len()); + ensure!(total <= N, "{}bytes can't fit in {} with padding", total, N); + // NOTE we don't pad anymore because we enforce that the resulting length is already a multiple + // of 4 so it will fit the conversion to u32 and circuit vk would stay the same for different + // data length + ensure!( + N % 4 == 0, + "Fixed array size must be 0 mod 4 for conversion with u32" + ); + + let unpadded_len = data.len(); + data.resize(N, 0); + Ok(Self { + data: data.try_into().unwrap(), + unpadded_len, + }) + } + + fn build_from_array, const D: usize>( + b: &mut CircuitBuilder, + a: &ArrayWire, + ) -> >::Wires { + let diff_target = b.add_virtual_target(); + let end_padding = b.add(a.real_len, diff_target); + let one = b.one(); + let end_padding = b.sub(end_padding, one); // inclusive range + // little endian so we start padding from the end of the byte + let single_pad = b.constant(F::from_canonical_usize(0x81)); // 1000 0001 + let begin_pad = b.constant(F::from_canonical_usize(0x01)); // 0000 0001 + let end_pad = b.constant(F::from_canonical_usize(0x80)); // 1000 0000 + // TODO : make that const generic + let padded_node = a + .arr + .iter() + .enumerate() + .map(|(i, byte)| { + let i_target = b.constant(F::from_canonical_usize(i)); + // condition if we are within the data range ==> i < length + let is_data = less_than(b, i_target, a.real_len, 32); + // condition if we start the padding ==> i == length + let is_start_padding = b.is_equal(i_target, a.real_len); + // condition if we are done with the padding ==> i == length + diff - 1 + let is_end_padding = b.is_equal(i_target, end_padding); + // condition if we only need to add one byte 1000 0001 to pad + // because we work on u8 data, we know we're at least adding 1 byte and in + // this case it's 0x81 = 1000 0001 + // i == length == diff - 1 + let is_start_and_end = b.and(is_start_padding, is_end_padding); + + // nikko XXX: Is this sound ? I think so but not 100% sure. + // I think it's ok to not use `quin_selector` or `b.random_acess` because + // if the prover gives another byte target, then the resulting hash would be invalid, + let item_data = b.mul(is_data.target, *byte); + let item_start_padding = b.mul(is_start_padding.target, begin_pad); + let item_end_padding = b.mul(is_end_padding.target, end_pad); + let item_start_and_end = b.mul(is_start_and_end.target, single_pad); + // if all of these conditions are false, then item will be 0x00,i.e. the padding + let mut item = item_data; + item = b.add(item, item_start_padding); + item = b.add(item, item_end_padding); + item = b.add(item, item_start_and_end); + item + }) + .collect::>(); + + // convert padded node to u32 + let node_u32_target: Vec = convert_u8_to_u32(b, &padded_node); + + // fixed size block delimitation: this is where we tell the hash function gadget + // to only look at a certain portion of our data, each bool says if the hash function + // will update its state for this block or not. + let rate_bytes = b.constant(F::from_canonical_usize(KECCAK256_R / 8)); + let end_padding_offset = b.add(end_padding, one); + let nb_blocks = b.div(end_padding_offset, rate_bytes); + // - 1 because keccak always take first block so we don't count it + let nb_actual_blocks = b.sub(nb_blocks, one); + let total_num_blocks = N / (KECCAK256_R / 8) - 1; + let blocks = (0..total_num_blocks) + .map(|i| { + let i_target = b.constant(F::from_canonical_usize(i)); + less_than(b, i_target, nb_actual_blocks, 8) + }) + .collect::>(); + + let hash_target = HashInputTarget { + input: BigUintTarget { + limbs: node_u32_target, + }, + input_bits: 0, + blocks, + }; + + let hash_output = b.hash_keccak256(&hash_target); + let output_array: [Target; 8] = hash_output + .limbs + .iter() + .map(|limb| limb.0) + .collect::>() + .try_into() + .expect("keccak256 should have 8 u32 limbs"); + KeccakWires { + input_array: a.clone(), + diff: diff_target, + output_array, + } + } + fn prove_from_array( + pw: &mut PartialWitness, + wires: &KeccakWires, + unpadded_len: usize, + ) { + let diff = compute_padding_size(unpadded_len); + pw.set_target(wires.diff, F::from_canonical_usize(diff)); + } +} + +impl UserCircuit for KeccakCircuit +where + F: RichField + Extendable, +{ + type Wires = KeccakWires; + + fn build(b: &mut CircuitBuilder) -> Self::Wires { + let real_len = b.add_virtual_target(); + let array = b.add_virtual_target_arr::(); + Self::build_from_array( + b, + &ArrayWire { + arr: array, + real_len, + }, + ) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + pw.set_int_targets(&wires.input_array.arr, &self.data); + pw.set_target( + wires.input_array.real_len, + F::from_canonical_usize(self.unpadded_len), + ); + Self::prove_from_array(pw, wires, self.unpadded_len); + } +} + +#[cfg(test)] +mod test { + use super::KeccakCircuit; + use crate::circuit::{PCDCircuit, ProofOrDummyTarget, UserCircuit}; + use plonky2::{ + field::extension::Extendable, hash::hash_types::RichField, + plonk::circuit_builder::CircuitBuilder, + }; + + impl PCDCircuit + for KeccakCircuit + where + F: RichField + Extendable, + { + fn build_recursive( + b: &mut CircuitBuilder, + _: &[ProofOrDummyTarget; ARITY], + ) -> Self::Wires { + let wires = >::build(b); + b.register_public_inputs(&wires.output_array); + wires + // TODO: check the proof public input match what is in the hash node for example for MPT + } + fn base_inputs(&self) -> Vec { + // since we don't care about the public inputs of the first + // proof (since we're not reading them , because we take array + // to hash as witness) + // 8 * u32 = 256 bits + F::rand_vec(8) + } + fn num_io() -> usize { + 8 + } + } +} diff --git a/src/lib.rs b/src/lib.rs index d41609f0b..037cf2053 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ pub mod eth; mod circuit; mod hash; +mod keccak; mod rlp; pub mod transaction; mod utils;