diff --git a/mp2-common/src/array.rs b/mp2-common/src/array.rs index 624e5e1bf..38872794c 100644 --- a/mp2-common/src/array.rs +++ b/mp2-common/src/array.rs @@ -638,7 +638,7 @@ where let (low_bits, high_bits) = b.split_low_high(at, 6, 12); // Search each of the smaller arrays for the target at `low_bits` - let first_search = arrays + let mut first_search = arrays .into_iter() .map(|array| { b.random_access( @@ -652,6 +652,10 @@ where }) .collect::>(); + // Now we push a number of zero targets into the array to make it a power of 2 + let next_power_of_two = first_search.len().next_power_of_two(); + let zero_target = b.zero(); + first_search.resize(next_power_of_two, zero_target); // Serach the result for the Target at `high_bits` T::from_target(b.random_access(high_bits, first_search)) } @@ -683,7 +687,7 @@ where let i_target = b.constant(F::from_canonical_usize(i)); let i_plus_n_target = b.add(at, i_target); - // out_val = arr[((i+n)<=n+M) * (i+n)] + self.random_access_large_array(b, i_plus_n_target) }), } diff --git a/mp2-common/src/eth.rs b/mp2-common/src/eth.rs index 1155f88b0..f01ec5176 100644 --- a/mp2-common/src/eth.rs +++ b/mp2-common/src/eth.rs @@ -467,23 +467,42 @@ impl ReceiptQuery { .into_iter() .map(|index| { let key = index.rlp_bytes(); + let index_size = key.len(); - let proof = block_util.receipts_trie.get_proof(&key)?; + + let proof = block_util.receipts_trie.get_proof(&key[..])?; + + // Since the compact encoding of the key is stored first plus an additional list header and + // then the first element in the receipt body is the transaction type we calculate the offset to that point + + let last_node = proof.last().ok_or(eth_trie::TrieError::DB( + "Could not get last node in proof".to_string(), + ))?; + + let list_length_hint = last_node[0] as usize - 247; + let key_length = if last_node[1 + list_length_hint] > 128 { + last_node[1 + list_length_hint] as usize - 128 + } else { + 0 + }; + let body_length_hint = last_node[2 + list_length_hint + key_length] as usize - 183; + let body_offset = 4 + list_length_hint + key_length + body_length_hint; + let receipt = block_util.txs[index as usize].receipt(); - let rlp_body = receipt.encoded_2718(); - // Skip the first byte as it refers to the transaction type - let length_hint = rlp_body[1] as usize - 247; - let status_offset = 2 + length_hint; - let gas_hint = rlp_body[3 + length_hint] as usize - 128; + let body_length_hint = last_node[body_offset] as usize - 247; + let length_hint = body_offset + body_length_hint; + + let status_offset = 1 + length_hint; + let gas_hint = last_node[2 + length_hint] as usize - 128; // Logs bloom is always 256 bytes long and comes after the gas used the first byte is 185 then 1 then 0 then the bloom so the // log data starts at 4 + length_hint + gas_hint + 259 - let log_offset = 4 + length_hint + gas_hint + 259; + let log_offset = 3 + length_hint + gas_hint + 259; - let log_hint = if rlp_body[log_offset] < 247 { - rlp_body[log_offset] as usize - 192 + let log_hint = if last_node[log_offset] < 247 { + last_node[log_offset] as usize - 192 } else { - rlp_body[log_offset] as usize - 247 + last_node[log_offset] as usize - 247 }; // We iterate through the logs and store the offsets we care about. let mut current_log_offset = log_offset + 1 + log_hint; @@ -632,11 +651,7 @@ impl BlockUtil { let body_rlp = receipt_primitive.encoded_2718(); let tx_body_rlp = transaction_primitive.encoded_2718(); - println!( - "TX index {} RLP encoded: {:?}", - receipt.transaction_index.unwrap(), - tx_index.to_vec() - ); + receipts_trie .insert(&tx_index, &body_rlp) .expect("can't insert receipt"); @@ -646,6 +661,8 @@ impl BlockUtil { TxWithReceipt(transaction.clone(), receipt_primitive) }) .collect::>(); + receipts_trie.root_hash()?; + transactions_trie.root_hash()?; Ok(BlockUtil { block, txs: consensus_receipts, @@ -700,11 +717,10 @@ mod tryethers { use ethers::{ providers::{Http, Middleware, Provider}, types::{ - Address, Block, BlockId, Bytes, EIP1186ProofResponse, Transaction, TransactionReceipt, - H256, U64, + Block, BlockId, Bytes, EIP1186ProofResponse, Transaction, TransactionReceipt, H256, U64, }, }; - use rlp::{Encodable, Rlp, RlpStream}; + use rlp::{Encodable, RlpStream}; /// A wrapper around a transaction and its receipt. The receipt is used to filter /// bad transactions, so we only compute over valid transactions. @@ -851,8 +867,8 @@ mod test { use alloy::{ node_bindings::Anvil, - primitives::{Bytes, Log}, - providers::ProviderBuilder, + primitives::{Bytes, Log, U256}, + providers::{ext::AnvilApi, Provider, ProviderBuilder, WalletProvider}, rlp::Decodable, sol, }; @@ -863,10 +879,10 @@ mod test { types::BlockNumber, }; use hashbrown::HashMap; + use tokio::task::JoinSet; use crate::{ mpt_sequential::utils::nibbles_to_bytes, - types::MAX_BLOCK_LEN, utils::{Endianness, Packer}, }; use mp2_test::eth::{get_mainnet_url, get_sepolia_url}; @@ -1002,14 +1018,11 @@ mod test { #[tokio::test] async fn test_receipt_query() -> Result<()> { - // Spin up a local node. - let anvil = Anvil::new().spawn(); - // Create a provider with the wallet for contract deployment and interaction. - let rpc_url = anvil.endpoint(); + let rpc = ProviderBuilder::new() + .with_recommended_fillers() + .on_anvil_with_wallet_and_config(|anvil| Anvil::block_time(anvil, 1)); - let rpc = ProviderBuilder::new().on_http(rpc_url.parse().unwrap()); - - // Make a contract taht emits events so we can pick up on them + // Make a contract that emits events so we can pick up on them sol! { #[allow(missing_docs)] // solc v0.8.26; solc Counter.sol --via-ir --optimize --bin @@ -1036,84 +1049,108 @@ mod test { } } // Deploy the contract using anvil - let contract = EventEmitter::deploy(&rpc).await?; + let contract = EventEmitter::deploy(rpc.clone()).await?; // Fire off a few transactions to emit some events - let mut transactions = Vec::::new(); - - for i in 0..10 { - if i % 2 == 0 { - let builder = contract.testEmit(); - let tx_hash = builder.send().await?.watch().await?; - let transaction = rpc.get_transaction_by_hash(tx_hash).await?.unwrap(); - transactions.push(transaction); - } else { - let builder = contract.twoEmits(); - let tx_hash = builder.send().await?.watch().await?; - let transaction = rpc.get_transaction_by_hash(tx_hash).await?.unwrap(); - transactions.push(transaction); - } + + let address = rpc.default_signer_address(); + rpc.anvil_set_nonce(address, U256::from(0)).await.unwrap(); + let tx_reqs = (0..10) + .map(|i| match i % 2 { + 0 => contract + .testEmit() + .into_transaction_request() + .nonce(i as u64), + 1 => contract + .twoEmits() + .into_transaction_request() + .nonce(i as u64), + _ => unreachable!(), + }) + .collect::>(); + let mut join_set = JoinSet::new(); + tx_reqs.into_iter().for_each(|tx_req| { + let rpc_clone = rpc.clone(); + join_set.spawn(async move { + rpc_clone + .send_transaction(tx_req) + .await + .unwrap() + .watch() + .await + .unwrap() + }); + }); + + let hashes = join_set.join_all().await; + let mut transactions = Vec::new(); + for hash in hashes.into_iter() { + transactions.push(rpc.get_transaction_by_hash(hash).await.unwrap().unwrap()); } + let block_number = transactions.first().unwrap().block_number.unwrap(); + // We want to get the event signature so we can make a ReceiptQuery let all_events = EventEmitter::abi::events(); let events = all_events.get("testEvent").unwrap(); let receipt_query = ReceiptQuery::new(*contract.address(), events[0].clone()); - // Now for each transaction we fetch the block, then get the MPT Trie proof that the receipt is included and verify it - for transaction in transactions.iter() { - let index = transaction - .block_number - .ok_or(anyhow!("Could not get block number from transaction"))?; - let block = rpc - .get_block( - BlockNumberOrTag::Number(index).into(), - alloy::rpc::types::BlockTransactionsKind::Full, - ) - .await? - .ok_or(anyhow!("Could not get block test"))?; - let proofs = receipt_query - .query_receipt_proofs(&rpc, BlockNumberOrTag::Number(index)) - .await?; - - for proof in proofs.into_iter() { - let memdb = Arc::new(MemoryDB::new(true)); - let tx_trie = EthTrie::new(Arc::clone(&memdb)); + let block = rpc + .get_block( + BlockNumberOrTag::Number(block_number).into(), + alloy::rpc::types::BlockTransactionsKind::Full, + ) + .await? + .ok_or(anyhow!("Could not get block test"))?; + let receipt_hash = block.header().receipts_root; + let proofs = receipt_query + .query_receipt_proofs(&rpc.root(), BlockNumberOrTag::Number(block_number)) + .await?; - let mpt_key = transaction.transaction_index.unwrap().rlp_bytes(); - let receipt_hash = block.header().receipts_root; - let is_valid = tx_trie - .verify_proof(receipt_hash.0.into(), &mpt_key, proof.mpt_proof.clone())? - .ok_or(anyhow!("No proof found when verifying"))?; + // Now for each transaction we fetch the block, then get the MPT Trie proof that the receipt is included and verify it - let expected_sig: [u8; 32] = keccak256(receipt_query.event.signature().as_bytes()) + for proof in proofs.iter() { + let memdb = Arc::new(MemoryDB::new(true)); + let tx_trie = EthTrie::new(Arc::clone(&memdb)); + + let mpt_key = proof.tx_index.rlp_bytes(); + + let _ = tx_trie + .verify_proof(receipt_hash.0.into(), &mpt_key, proof.mpt_proof.clone())? + .ok_or(anyhow!("No proof found when verifying"))?; + + let last_node = proof + .mpt_proof + .last() + .ok_or(anyhow!("Couldn't get first node in proof"))?; + let expected_sig: [u8; 32] = keccak256(receipt_query.event.signature().as_bytes()) + .try_into() + .unwrap(); + + for log_offset in proof.relevant_logs_offset.iter() { + let mut buf = &last_node[*log_offset..*log_offset + proof.event_log_info.size]; + let decoded_log = Log::decode(&mut buf)?; + let raw_bytes: [u8; 20] = last_node[*log_offset + + proof.event_log_info.add_rel_offset + ..*log_offset + proof.event_log_info.add_rel_offset + 20] + .to_vec() .try_into() .unwrap(); - - for log_offset in proof.relevant_logs_offset.iter() { - let mut buf = &is_valid[*log_offset..*log_offset + proof.event_log_info.size]; - let decoded_log = Log::decode(&mut buf)?; - let raw_bytes: [u8; 20] = is_valid[*log_offset - + proof.event_log_info.add_rel_offset - ..*log_offset + proof.event_log_info.add_rel_offset + 20] - .to_vec() - .try_into() - .unwrap(); - assert_eq!(decoded_log.address, receipt_query.contract); - assert_eq!(raw_bytes, receipt_query.contract); - let topics = decoded_log.topics(); - assert_eq!(topics[0].0, expected_sig); - let raw_bytes: [u8; 32] = is_valid[*log_offset - + proof.event_log_info.sig_rel_offset - ..*log_offset + proof.event_log_info.sig_rel_offset + 32] - .to_vec() - .try_into() - .unwrap(); - assert_eq!(topics[0].0, raw_bytes); - } + assert_eq!(decoded_log.address, receipt_query.contract); + assert_eq!(raw_bytes, receipt_query.contract); + let topics = decoded_log.topics(); + assert_eq!(topics[0].0, expected_sig); + let raw_bytes: [u8; 32] = last_node[*log_offset + + proof.event_log_info.sig_rel_offset + ..*log_offset + proof.event_log_info.sig_rel_offset + 32] + .to_vec() + .try_into() + .unwrap(); + assert_eq!(topics[0].0, raw_bytes); } } + Ok(()) } diff --git a/mp2-common/src/group_hashing/mod.rs b/mp2-common/src/group_hashing/mod.rs index 05c0d34ca..819eb7c2b 100644 --- a/mp2-common/src/group_hashing/mod.rs +++ b/mp2-common/src/group_hashing/mod.rs @@ -21,8 +21,6 @@ use plonky2_ecgfp5::{ }, }; -use std::array::from_fn as create_array; - mod curve_add; pub mod field_to_curve; mod sswu_gadget; diff --git a/mp2-common/src/mpt_sequential/leaf_or_extension.rs b/mp2-common/src/mpt_sequential/leaf_or_extension.rs index 8c64d7584..e5c0cf482 100644 --- a/mp2-common/src/mpt_sequential/leaf_or_extension.rs +++ b/mp2-common/src/mpt_sequential/leaf_or_extension.rs @@ -1,6 +1,8 @@ //! MPT leaf or extension node gadget -use super::{advance_key_leaf_or_extension, key::MPTKeyWireGeneric, PAD_LEN}; +use super::{ + advance_key_leaf_or_extension, advance_key_receipt_leaf, key::MPTKeyWireGeneric, PAD_LEN, +}; use crate::{ array::{Array, Vector, VectorWire}, keccak::{InputData, KeccakCircuit, KeccakWires}, @@ -96,3 +98,61 @@ impl MPTLeafOrExtensionNodeGeneric { } } } + +/// Wrapped wires for a MPT receipt leaf +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct MPTReceiptLeafWiresGeneric +where + [(); PAD_LEN(NODE_LEN)]:, +{ + /// MPT node + pub node: VectorWire, + /// MPT root + pub root: KeccakWires<{ PAD_LEN(NODE_LEN) }>, + /// New MPT key after advancing the current key + pub key: MPTKeyWireGeneric, +} + +/// Receipt leaf node as we have to do things differently for efficiency reasons. +pub struct MPTReceiptLeafNode; + +impl MPTReceiptLeafNode { + /// Build the MPT node and advance the current key. + pub fn build_and_advance_key< + F: RichField + Extendable, + const D: usize, + const NODE_LEN: usize, + >( + b: &mut CircuitBuilder, + current_key: &MPTKeyWireGeneric, + ) -> MPTReceiptLeafWiresGeneric + where + [(); PAD_LEN(NODE_LEN)]:, + { + let zero = b.zero(); + let tru = b._true(); + + // Build the node and ensure it only includes bytes. + let node = VectorWire::::new(b); + + node.assert_bytes(b); + + // Expose the keccak root of this subtree starting at this node. + let root = KeccakCircuit::<{ PAD_LEN(NODE_LEN) }>::hash_vector(b, &node); + + // We know that the rlp encoding of the compact encoding of the key is going to be in roughly the first 10 bytes of + // the node since the node is list byte, 2 bytes for list length (maybe 3), key length byte (1), key compact encoding (4 max) + // so we take 10 bytes to be safe since this won't effect the number of random access gates we use. + let rlp_headers = decode_fixed_list::<_, D, 1>(b, &node.arr.arr[..10], zero); + + let (key, valid) = advance_key_receipt_leaf::( + b, + &node, + current_key, + &rlp_headers, + ); + b.connect(tru.target, valid.target); + + MPTReceiptLeafWiresGeneric { node, root, key } + } +} diff --git a/mp2-common/src/mpt_sequential/mod.rs b/mp2-common/src/mpt_sequential/mod.rs index 3c6dd8be4..e4518401a 100644 --- a/mp2-common/src/mpt_sequential/mod.rs +++ b/mp2-common/src/mpt_sequential/mod.rs @@ -38,7 +38,7 @@ pub use key::{ }; pub use leaf_or_extension::{ MPTLeafOrExtensionNode, MPTLeafOrExtensionNodeGeneric, MPTLeafOrExtensionWires, - MPTLeafOrExtensionWiresGeneric, + MPTLeafOrExtensionWiresGeneric, MPTReceiptLeafNode, MPTReceiptLeafWiresGeneric, }; /// Number of items in the RLP encoded list in a leaf node. @@ -52,7 +52,7 @@ pub const MAX_LEAF_VALUE_LEN: usize = 33; /// This is the maximum size we allow for the value of Receipt Trie leaf /// currently set to be the same as we allow for a branch node in the Storage Trie /// minus the length of the key header and key -pub const MAX_RECEIPT_LEAF_VALUE_LEN: usize = 526; +pub const MAX_RECEIPT_LEAF_VALUE_LEN: usize = 503; /// RLP item size for the extension node pub const MPT_EXTENSION_RLP_SIZE: usize = 2; @@ -443,6 +443,44 @@ pub fn advance_key_leaf_or_extension< let condition = b.and(condition, should_true); (new_key, leaf_child_hash, condition) } + +/// Returns the key with the pointer moved in the case of a Receipt Trie leaf. +pub fn advance_key_receipt_leaf< + F: RichField + Extendable, + const D: usize, + const NODE_LEN: usize, + const KEY_LEN: usize, +>( + b: &mut CircuitBuilder, + node: &VectorWire, + key: &MPTKeyWireGeneric, + rlp_headers: &RlpList<1>, +) -> (MPTKeyWireGeneric, BoolTarget) { + let key_header = RlpHeader { + data_type: rlp_headers.data_type[0], + offset: rlp_headers.offset[0], + len: rlp_headers.len[0], + }; + + // To save on operations we know the key is goin to be in the first 10 items so we + // only feed these into `decode_compact_encoding` + let sub_array: Array = Array { + arr: create_array(|i| node.arr.arr[i]), + }; + let (extracted_key, should_true) = + decode_compact_encoding::<_, _, _, KEY_LEN>(b, &sub_array, &key_header); + + // note we are going _backwards_ on the key, so we need to substract the expected key length + // we want to check against + let new_key = key.advance_by(b, extracted_key.real_len); + // NOTE: there is no need to check if the extracted_key is indeed a subvector of the full key + // in this case. Indeed, in leaf/ext. there is only one key possible. Since we decoded it + // from the beginning of the node, and that the hash of the node also starts at the beginning, + // either the attacker give the right node or it gives an invalid node and hashes will not + // match. + + (new_key, should_true) +} #[cfg(test)] mod test { use std::array::from_fn as create_array; diff --git a/mp2-test/src/circuit.rs b/mp2-test/src/circuit.rs index 262d4384e..588197b42 100644 --- a/mp2-test/src/circuit.rs +++ b/mp2-test/src/circuit.rs @@ -102,9 +102,9 @@ pub fn prove_circuit< let mut pw = PartialWitness::new(); println!("[+] Generating a proof ... "); - let now = std::time::Instant::now(); u.prove(&mut pw, &setup.0); let proof = setup.1.prove(pw).expect("invalid proof"); + println!("[+] Proof generated in {:?}ms", now.elapsed().as_millis()); setup .2 @@ -124,6 +124,7 @@ pub fn run_circuit< u: U, ) -> ProofWithPublicInputs { let setup = setup_circuit::(); + println!( "setup.verifierdata hash {:?}", setup.2.verifier_only.circuit_digest @@ -131,3 +132,100 @@ pub fn run_circuit< prove_circuit(&setup, &u) } + +/// Given a `PartitionWitness` that has only inputs set, populates the rest of the witness using the +/// given set of generators. +pub fn debug_generate_partial_witness< + 'a, + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + inputs: PartialWitness, + prover_data: &'a plonky2::plonk::circuit_data::ProverOnlyCircuitData, + common_data: &'a plonky2::plonk::circuit_data::CommonCircuitData, +) -> plonky2::iop::witness::PartitionWitness<'a, F> { + use plonky2::iop::witness::WitnessWrite; + + let config = &common_data.config; + let generators = &prover_data.generators; + let generator_indices_by_watches = &prover_data.generator_indices_by_watches; + + let mut witness = plonky2::iop::witness::PartitionWitness::new( + config.num_wires, + common_data.degree(), + &prover_data.representative_map, + ); + + for (t, v) in inputs.target_values.into_iter() { + witness.set_target(t, v); + } + + // Build a list of "pending" generators which are queued to be run. Initially, all generators + // are queued. + let mut pending_generator_indices: Vec<_> = (0..generators.len()).collect(); + + // We also track a list of "expired" generators which have already returned false. + let mut generator_is_expired = vec![false; generators.len()]; + let mut remaining_generators = generators.len(); + + let mut buffer = plonky2::iop::generator::GeneratedValues::empty(); + + // Keep running generators until we fail to make progress. + while !pending_generator_indices.is_empty() { + let mut next_pending_generator_indices = Vec::new(); + + for &generator_idx in &pending_generator_indices { + if generator_is_expired[generator_idx] { + continue; + } + + let finished = generators[generator_idx].0.run(&witness, &mut buffer); + if finished { + generator_is_expired[generator_idx] = true; + remaining_generators -= 1; + } + + // Merge any generated values into our witness, and get a list of newly-populated + // targets' representatives. + let new_target_reps = buffer + .target_values + .drain(..) + .flat_map(|(t, v)| witness.set_target_returning_rep(t, v)); + + // Enqueue unfinished generators that were watching one of the newly populated targets. + for watch in new_target_reps { + let opt_watchers = generator_indices_by_watches.get(&watch); + if let Some(watchers) = opt_watchers { + for &watching_generator_idx in watchers { + if !generator_is_expired[watching_generator_idx] { + next_pending_generator_indices.push(watching_generator_idx); + } + } + } + } + } + + pending_generator_indices = next_pending_generator_indices; + } + if remaining_generators != 0 { + println!("{} generators weren't run", remaining_generators); + + let filtered = generator_is_expired + .iter() + .enumerate() + .filter_map(|(index, flag)| if !flag { Some(index) } else { None }) + .min(); + + if let Some(min_val) = filtered { + println!("generator at index: {} is the first to not run", min_val); + println!("This has ID: {}", generators[min_val].0.id()); + + for watch in generators[min_val].0.watch_list().iter() { + println!("watching: {:?}", watch); + } + } + } + + witness +} diff --git a/mp2-test/src/mpt_sequential.rs b/mp2-test/src/mpt_sequential.rs index d1e79caa1..570170235 100644 --- a/mp2-test/src/mpt_sequential.rs +++ b/mp2-test/src/mpt_sequential.rs @@ -2,8 +2,7 @@ use alloy::{ eips::BlockNumberOrTag, node_bindings::Anvil, primitives::U256, - providers::{ext::AnvilApi, Provider, ProviderBuilder, RootProvider, WalletProvider}, - rpc::types::Transaction, + providers::{ext::AnvilApi, Provider, ProviderBuilder, WalletProvider}, sol, }; use eth_trie::{EthTrie, MemoryDB, Trie}; @@ -53,7 +52,7 @@ pub fn generate_random_storage_mpt( /// This function is used so that we can generate a Receipt Trie for a blog with varying transactions /// (i.e. some we are interested in and some we are not). -fn generate_receipt_proofs() -> Vec { +pub fn generate_receipt_proofs() -> Vec { // Make a contract that emits events so we can pick up on them sol! { #[allow(missing_docs)] @@ -179,15 +178,3 @@ fn generate_receipt_proofs() -> Vec { .unwrap() }) } - -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn tester() { - let receipt_proofs = generate_receipt_proofs(); - for proof in receipt_proofs.iter() { - println!("proof: {}", proof.tx_index); - } - } -} diff --git a/mp2-v1/src/lib.rs b/mp2-v1/src/lib.rs index b760a714b..152295ed8 100644 --- a/mp2-v1/src/lib.rs +++ b/mp2-v1/src/lib.rs @@ -14,6 +14,7 @@ pub const MAX_BRANCH_NODE_LEN_PADDED: usize = PAD_LEN(532); pub const MAX_EXTENSION_NODE_LEN: usize = 69; pub const MAX_EXTENSION_NODE_LEN_PADDED: usize = PAD_LEN(69); pub const MAX_LEAF_NODE_LEN: usize = MAX_EXTENSION_NODE_LEN; +pub const MAX_RECEIPT_LEAF_NODE_LEN: usize = 512; pub mod api; pub mod block_extraction; diff --git a/mp2-v1/src/receipt_extraction/leaf.rs b/mp2-v1/src/receipt_extraction/leaf.rs index f7c99d8a7..8fca8a1c5 100644 --- a/mp2-v1/src/receipt_extraction/leaf.rs +++ b/mp2-v1/src/receipt_extraction/leaf.rs @@ -1,17 +1,15 @@ //! Module handling the leaf node inside a Receipt Trie -use super::public_inputs::PublicInputArgs; +use crate::MAX_RECEIPT_LEAF_NODE_LEN; + +use super::public_inputs::{PublicInputArgs, PublicInputs}; use mp2_common::{ array::{Array, Vector, VectorWire}, eth::{EventLogInfo, LogDataInfo, ReceiptProofInfo}, group_hashing::CircuitBuilderGroupHashing, keccak::{InputData, KeccakCircuit, KeccakWires}, - mpt_sequential::{ - MPTLeafOrExtensionNodeGeneric, ReceiptKeyWire, MAX_RECEIPT_LEAF_VALUE_LEN, - MAX_TX_KEY_NIBBLE_LEN, PAD_LEN, - }, - poseidon::H, + mpt_sequential::{MPTReceiptLeafNode, ReceiptKeyWire, MAX_TX_KEY_NIBBLE_LEN, PAD_LEN}, public_inputs::PublicInputCommon, types::{CBuilder, GFp}, utils::{Endianness, PackerTarget}, @@ -23,13 +21,15 @@ use plonky2::{ target::Target, witness::{PartialWitness, WitnessWrite}, }, + plonk::circuit_builder::CircuitBuilder, }; use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; +use recursion_framework::circuit_builder::CircuitLogicWires; use rlp::Encodable; use serde::{Deserialize, Serialize}; - +use std::array::from_fn; /// Maximum number of logs per transaction we can process const MAX_LOGS_PER_TX: usize = 2; @@ -42,10 +42,10 @@ where pub event: EventWires, /// The node bytes pub node: VectorWire, - /// The actual value stored in the node - pub value: Array, /// the hash of the node bytes pub root: KeccakWires<{ PAD_LEN(NODE_LEN) }>, + /// The index of this receipt in the block + pub index: Target, /// The offset of the status of the transaction in the RLP encoded receipt node. pub status_offset: Target, /// The offsets of the relevant logs inside the node @@ -102,7 +102,7 @@ impl LogColumn { impl EventWires { /// Convert to an array for metadata digest - pub fn to_slice(&self) -> [Target; 70] { + pub fn to_vec(&self) -> Vec { let topics_flat = self .topics .iter() @@ -113,60 +113,45 @@ impl EventWires { .iter() .flat_map(|t| t.to_array()) .collect::>(); - let mut out = [Target::default(); 70]; - out[0] = self.size; - out.iter_mut() - .skip(1) - .take(20) - .enumerate() - .for_each(|(i, entry)| *entry = self.address.arr[i]); - out[21] = self.add_rel_offset; - out.iter_mut() - .skip(22) - .take(32) - .enumerate() - .for_each(|(i, entry)| *entry = self.event_signature.arr[i]); - out[54] = self.sig_rel_offset; - out.iter_mut() - .skip(55) - .take(9) - .enumerate() - .for_each(|(i, entry)| *entry = topics_flat[i]); - out.iter_mut() - .skip(64) - .take(6) - .enumerate() - .for_each(|(i, entry)| *entry = data_flat[i]); + let mut out = Vec::new(); + out.push(self.size); + out.extend_from_slice(&self.address.arr); + out.push(self.add_rel_offset); + out.extend_from_slice(&self.event_signature.arr); + out.push(self.sig_rel_offset); + out.extend_from_slice(&topics_flat); + out.extend_from_slice(&data_flat); + out } - pub fn verify_logs_and_extract_values( + pub fn verify_logs_and_extract_values( &self, b: &mut CBuilder, - value: &Array, + value: &VectorWire, status_offset: Target, relevant_logs_offsets: &VectorWire, ) -> CurveTarget { let t = b._true(); let zero = b.zero(); let curve_zero = b.curve_zero(); - let mut value_digest = b.curve_zero(); + let mut points = Vec::new(); // Enforce status is true. - let status = value.random_access_large_array(b, status_offset); + let status = value.arr.random_access_large_array(b, status_offset); b.connect(status, t.target); for log_offset in relevant_logs_offsets.arr.arr { // Extract the address bytes let address_start = b.add(log_offset, self.add_rel_offset); - let address_bytes = value.extract_array_large::<_, _, 20>(b, address_start); + let address_bytes = value.arr.extract_array_large::<_, _, 20>(b, address_start); let address_check = address_bytes.equals(b, &self.address); // Extract the signature bytes let sig_start = b.add(log_offset, self.sig_rel_offset); - let sig_bytes = value.extract_array_large::<_, _, 32>(b, sig_start); + let sig_bytes = value.arr.extract_array_large::<_, _, 32>(b, sig_start); let sig_check = sig_bytes.equals(b, &self.event_signature); @@ -182,7 +167,7 @@ impl EventWires { for &log_column in self.topics.iter().chain(self.data.iter()) { let data_start = b.add(log_offset, log_column.rel_byte_offset); // The data is always 32 bytes long - let data_bytes = value.extract_array_large::<_, _, 32>(b, data_start); + let data_bytes = value.arr.extract_array_large::<_, _, 32>(b, data_start); // Pack the data and get the digest let packed_data = data_bytes.arr.pack(b, Endianness::Big); @@ -197,11 +182,11 @@ impl EventWires { let selector = b.and(dummy_column, dummy); let selected_point = b.select_curve_point(selector, curve_zero, data_digest); - value_digest = b.add_curve_point(&[selected_point, value_digest]); + points.push(selected_point); } } - value_digest + b.add_curve_point(&points) } } @@ -215,7 +200,7 @@ impl ReceiptLeafCircuit where [(); PAD_LEN(NODE_LEN)]:, { - pub fn build_leaf_wires(b: &mut CBuilder) -> ReceiptLeafWires { + pub fn build(b: &mut CBuilder) -> ReceiptLeafWires { // Build the event wires let event_wires = Self::build_event_wires(b); @@ -227,27 +212,24 @@ where let mpt_key = ReceiptKeyWire::new(b); // Build the node wires. - let wires = MPTLeafOrExtensionNodeGeneric::build_and_advance_key::< - _, - D, - NODE_LEN, - MAX_RECEIPT_LEAF_VALUE_LEN, - >(b, &mpt_key); + let wires = MPTReceiptLeafNode::build_and_advance_key::<_, D, NODE_LEN>(b, &mpt_key); + let node = wires.node; let root = wires.root; // For each relevant log in the transaction we have to verify it lines up with the event we are monitoring for - let receipt_body = wires.value; - let mut dv = event_wires.verify_logs_and_extract_values( + let mut dv = event_wires.verify_logs_and_extract_values::( b, - &receipt_body, + &node, status_offset, &relevant_logs_offset, ); + let value_id = b.map_to_curve_point(&[index]); + dv = b.add_curve_point(&[value_id, dv]); - let dm = b.hash_n_to_hash_no_pad::(event_wires.to_slice().to_vec()); + let dm = b.map_to_curve_point(&event_wires.to_vec()); // Register the public inputs PublicInputArgs { @@ -261,8 +243,8 @@ where ReceiptLeafWires { event: event_wires, node, - value: receipt_body, root, + index, status_offset, relevant_logs_offset, mpt_key, @@ -273,24 +255,22 @@ where let size = b.add_virtual_target(); // Packed address - let arr = [b.add_virtual_target(); 20]; - let address = Array::from_array(arr); + let address = Array::::new(b); // relative offset of the address let add_rel_offset = b.add_virtual_target(); // Event signature - let arr = [b.add_virtual_target(); 32]; - let event_signature = Array::from_array(arr); + let event_signature = Array::::new(b); // Signature relative offset let sig_rel_offset = b.add_virtual_target(); // topics - let topics = [Self::build_log_column(b); 3]; + let topics: [LogColumn; 3] = from_fn(|_| Self::build_log_column(b)); // data - let data = [Self::build_log_column(b); 2]; + let data: [LogColumn; 2] = from_fn(|_| Self::build_log_column(b)); EventWires { size, @@ -331,7 +311,7 @@ where &wires.root, &InputData::Assigned(&pad_node), ); - + pw.set_target(wires.index, GFp::from_canonical_u64(self.info.tx_index)); pw.set_target( wires.status_offset, GFp::from_canonical_usize(self.info.status_offset), @@ -406,13 +386,47 @@ where } } +/// Num of children = 0 +impl CircuitLogicWires for ReceiptLeafWires { + type CircuitBuilderParams = (); + + type Inputs = ReceiptLeafCircuit; + + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + + fn circuit_logic( + builder: &mut CircuitBuilder, + _verified_proofs: [&plonky2::plonk::proof::ProofWithPublicInputsTarget; 0], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + ReceiptLeafCircuit::build(builder) + } + + fn assign_input( + &self, + inputs: Self::Inputs, + pw: &mut PartialWitness, + ) -> anyhow::Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; + use crate::receipt_extraction::compute_receipt_leaf_metadata_digest; + use mp2_common::{ + utils::{keccak256, Packer}, + C, + }; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + mpt_sequential::generate_receipt_proofs, + }; #[derive(Clone, Debug)] struct TestReceiptLeafCircuit { c: ReceiptLeafCircuit, - exp_value: Vec, } impl UserCircuit for TestReceiptLeafCircuit @@ -420,91 +434,38 @@ mod tests { [(); PAD_LEN(NODE_LEN)]:, { // Leaf wires + expected extracted value - type Wires = ( - ReceiptLeafWires, - Array, - ); + type Wires = ReceiptLeafWires; fn build(b: &mut CircuitBuilder) -> Self::Wires { - let exp_value = Array::::new(b); - - let leaf_wires = ReceiptLeafCircuit::::build(b); - leaf_wires.value.enforce_equal(b, &exp_value); - - (leaf_wires, exp_value) + ReceiptLeafCircuit::::build(b) } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - wires - .1 - .assign_bytes(pw, &self.exp_value.clone().try_into().unwrap()); + self.c.assign(pw, &wires); } } #[test] fn test_leaf_circuit() { - const NODE_LEN: usize = 80; - - let simple_slot = 2_u8; - let slot = StorageSlot::Simple(simple_slot as usize); - let contract_address = Address::from_str(TEST_CONTRACT_ADDRESS).unwrap(); - let chain_id = 10; - let id = identifier_single_var_column(simple_slot, &contract_address, chain_id, vec![]); - - let (mut trie, _) = generate_random_storage_mpt::<3, MAPPING_LEAF_VALUE_LEN>(); - let value = random_vector(MAPPING_LEAF_VALUE_LEN); - let encoded_value: Vec = rlp::encode(&value).to_vec(); - // assert we added one byte of RLP header - assert_eq!(encoded_value.len(), MAPPING_LEAF_VALUE_LEN + 1); - println!("encoded value {:?}", encoded_value); - trie.insert(&slot.mpt_key(), &encoded_value).unwrap(); - trie.root_hash().unwrap(); - - let proof = trie.get_proof(&slot.mpt_key_vec()).unwrap(); - let node = proof.last().unwrap().clone(); - - let c = LeafSingleCircuit:: { - node: node.clone(), - slot: SimpleSlot::new(simple_slot), - id, - }; - let test_circuit = TestLeafSingleCircuit { - c, - exp_value: value.clone(), - }; + const NODE_LEN: usize = 512; + + let receipt_proof_infos = generate_receipt_proofs(); + let info = receipt_proof_infos.first().unwrap().clone(); + let c = ReceiptLeafCircuit:: { info: info.clone() }; + let test_circuit = TestReceiptLeafCircuit { c }; let proof = run_circuit::(test_circuit); let pi = PublicInputs::new(&proof.public_inputs); - + let node = info.mpt_proof.last().unwrap().clone(); + // Check the output hash { let exp_hash = keccak256(&node).pack(Endianness::Little); assert_eq!(pi.root_hash(), exp_hash); } - { - let (key, ptr) = pi.mpt_key_info(); - - let exp_key = slot.mpt_key_vec(); - let exp_key: Vec<_> = bytes_to_nibbles(&exp_key) - .into_iter() - .map(F::from_canonical_u8) - .collect(); - assert_eq!(key, exp_key); - - let leaf_key: Vec> = rlp::decode_list(&node); - let nib = Nibbles::from_compact(&leaf_key[0]); - let exp_ptr = F::from_canonical_usize(MAX_KEY_NIBBLE_LEN - 1 - nib.nibbles().len()); - assert_eq!(exp_ptr, ptr); - } - // Check values digest - { - let exp_digest = compute_leaf_single_values_digest(id, &value); - assert_eq!(pi.values_digest(), exp_digest.to_weierstrass()); - } + // Check metadata digest { - let exp_digest = compute_leaf_single_metadata_digest(id, simple_slot); + let exp_digest = compute_receipt_leaf_metadata_digest(&info.event_log_info); assert_eq!(pi.metadata_digest(), exp_digest.to_weierstrass()); } - assert_eq!(pi.n(), F::ONE); } -} \ No newline at end of file +} diff --git a/mp2-v1/src/receipt_extraction/mod.rs b/mp2-v1/src/receipt_extraction/mod.rs index 6c3803e08..4950aef20 100644 --- a/mp2-v1/src/receipt_extraction/mod.rs +++ b/mp2-v1/src/receipt_extraction/mod.rs @@ -1,2 +1,31 @@ pub mod leaf; pub mod public_inputs; + +use mp2_common::{ + digest::Digest, eth::EventLogInfo, group_hashing::map_to_curve_point, types::GFp, +}; +use plonky2::field::types::Field; + +/// Calculate `metadata_digest = D(key_id || value_id || slot)` for receipt leaf. +pub fn compute_receipt_leaf_metadata_digest(event: &EventLogInfo) -> Digest { + let topics_flat = event + .topics + .iter() + .chain(event.data.iter()) + .flat_map(|t| [t.column_id, t.rel_byte_offset, t.len]) + .collect::>(); + + let mut out = Vec::new(); + out.push(event.size); + out.extend_from_slice(&event.address.0.map(|byte| byte as usize)); + out.push(event.add_rel_offset); + out.extend_from_slice(&event.event_signature.map(|byte| byte as usize)); + out.push(event.sig_rel_offset); + out.extend_from_slice(&topics_flat); + + let data = out + .into_iter() + .map(GFp::from_canonical_usize) + .collect::>(); + map_to_curve_point(&data) +} diff --git a/mp2-v1/src/receipt_extraction/public_inputs.rs b/mp2-v1/src/receipt_extraction/public_inputs.rs index 901fc0b29..7a44ed175 100644 --- a/mp2-v1/src/receipt_extraction/public_inputs.rs +++ b/mp2-v1/src/receipt_extraction/public_inputs.rs @@ -1,14 +1,22 @@ //! Public inputs for Receipt Extraction circuits use mp2_common::{ + array::Array, keccak::{OutputHash, PACKED_HASH_LEN}, mpt_sequential::ReceiptKeyWire, public_inputs::{PublicInputCommon, PublicInputRange}, - types::{CBuilder, CURVE_TARGET_LEN}, + types::{CBuilder, GFp, GFp5, CURVE_TARGET_LEN}, + utils::{convert_point_to_curve_target, convert_slice_to_curve_point, FromTargets}, }; -use plonky2::hash::hash_types::{HashOutTarget, NUM_HASH_OUT_ELTS}; -use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; +use plonky2::{ + field::{extension::FieldExtension, types::Field}, + iop::target::Target, +}; +use plonky2_ecgfp5::{ + curve::curve::WeierstrassPoint, + gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}, +}; /// The maximum length of a transaction index in a block in nibbles. /// Theoretically a block can have up to 1428 transactions in Ethereum, which takes 3 bytes to represent. @@ -23,7 +31,7 @@ const T_RANGE: PublicInputRange = K_RANGE.end..K_RANGE.end + 1; /// - `DV : Digest[F]` : value digest of all rows to extract const DV_RANGE: PublicInputRange = T_RANGE.end..T_RANGE.end + CURVE_TARGET_LEN; /// - `DM : Digest[F]` : metadata digest to extract -const DM_RANGE: PublicInputRange = DV_RANGE.end..DV_RANGE.end + NUM_HASH_OUT_ELTS; +const DM_RANGE: PublicInputRange = DV_RANGE.end..DV_RANGE.end + CURVE_TARGET_LEN; /// Public inputs for contract extraction #[derive(Clone, Debug)] @@ -35,7 +43,7 @@ pub struct PublicInputArgs<'a> { /// Digest of the values pub(crate) dv: CurveTarget, /// The poseidon hash of the metadata - pub(crate) dm: HashOutTarget, + pub(crate) dm: CurveTarget, } impl<'a> PublicInputCommon for PublicInputArgs<'a> { @@ -48,12 +56,7 @@ impl<'a> PublicInputCommon for PublicInputArgs<'a> { impl<'a> PublicInputArgs<'a> { /// Create a new public inputs. - pub fn new( - h: &'a OutputHash, - k: &'a ReceiptKeyWire, - dv: CurveTarget, - dm: HashOutTarget, - ) -> Self { + pub fn new(h: &'a OutputHash, k: &'a ReceiptKeyWire, dv: CurveTarget, dm: CurveTarget) -> Self { Self { h, k, dv, dm } } } @@ -63,14 +66,105 @@ impl<'a> PublicInputArgs<'a> { self.h.register_as_public_input(cb); self.k.register_as_input(cb); cb.register_curve_public_input(self.dv); - cb.register_public_inputs(&self.dm.elements); + cb.register_curve_public_input(self.dm); } pub fn digest_value(&self) -> CurveTarget { self.dv } - pub fn digest_metadata(&self) -> HashOutTarget { + pub fn digest_metadata(&self) -> CurveTarget { self.dm } } + +/// Public inputs wrapper of any proof generated in this module +#[derive(Clone, Debug)] +pub struct PublicInputs<'a, T> { + pub(crate) proof_inputs: &'a [T], +} + +impl PublicInputs<'_, Target> { + /// Get the merkle hash of the subtree this proof has processed. + pub fn root_hash_target(&self) -> OutputHash { + OutputHash::from_targets(self.root_hash_info()) + } + + /// Get the MPT key defined over the public inputs. + pub fn mpt_key(&self) -> ReceiptKeyWire { + let (key, ptr) = self.mpt_key_info(); + ReceiptKeyWire { + key: Array { + arr: std::array::from_fn(|i| key[i]), + }, + pointer: ptr, + } + } + + /// Get the values digest defined over the public inputs. + pub fn values_digest_target(&self) -> CurveTarget { + convert_point_to_curve_target(self.values_digest_info()) + } + + /// Get the metadata digest defined over the public inputs. + pub fn metadata_digest_target(&self) -> CurveTarget { + convert_point_to_curve_target(self.metadata_digest_info()) + } +} + +impl PublicInputs<'_, GFp> { + /// Get the merkle hash of the subtree this proof has processed. + pub fn root_hash(&self) -> Vec { + let hash = self.root_hash_info(); + hash.iter().map(|t| t.0 as u32).collect() + } + + /// Get the values digest defined over the public inputs. + pub fn values_digest(&self) -> WeierstrassPoint { + let (x, y, is_inf) = self.values_digest_info(); + + WeierstrassPoint { + x: GFp5::from_basefield_array(std::array::from_fn::(|i| x[i])), + y: GFp5::from_basefield_array(std::array::from_fn::(|i| y[i])), + is_inf: is_inf.is_nonzero(), + } + } + + /// Get the metadata digest defined over the public inputs. + pub fn metadata_digest(&self) -> WeierstrassPoint { + let (x, y, is_inf) = self.metadata_digest_info(); + + WeierstrassPoint { + x: GFp5::from_basefield_array(std::array::from_fn::(|i| x[i])), + y: GFp5::from_basefield_array(std::array::from_fn::(|i| y[i])), + is_inf: is_inf.is_nonzero(), + } + } +} + +impl<'a, T: Copy> PublicInputs<'a, T> { + pub(crate) const TOTAL_LEN: usize = DM_RANGE.end; + + pub fn new(proof_inputs: &'a [T]) -> Self { + Self { proof_inputs } + } + + pub fn root_hash_info(&self) -> &[T] { + &self.proof_inputs[H_RANGE] + } + + pub fn mpt_key_info(&self) -> (&[T], T) { + let key = &self.proof_inputs[K_RANGE]; + let ptr = self.proof_inputs[T_RANGE.start]; + + (key, ptr) + } + + pub fn values_digest_info(&self) -> ([T; 5], [T; 5], T) { + convert_slice_to_curve_point(&self.proof_inputs[DV_RANGE]) + } + + pub fn metadata_digest_info(&self) -> ([T; 5], [T; 5], T) { + convert_slice_to_curve_point(&self.proof_inputs[DM_RANGE]) + } +} diff --git a/mp2-v1/src/values_extraction/api.rs b/mp2-v1/src/values_extraction/api.rs index cbd810010..1ecef4600 100644 --- a/mp2-v1/src/values_extraction/api.rs +++ b/mp2-v1/src/values_extraction/api.rs @@ -153,7 +153,7 @@ macro_rules! impl_branch_circuits { pub type $struct_name = [< $struct_name GenericNodeLen>]; impl $struct_name { - fn new(builder: &CircuitWithUniversalVerifierBuilder) -> Self { + pub fn new(builder: &CircuitWithUniversalVerifierBuilder) -> Self { $struct_name { $( // generate one circuit with full node len @@ -162,7 +162,7 @@ macro_rules! impl_branch_circuits { } } /// Returns the set of circuits to be fed to the recursive framework - fn circuit_set(&self) -> Vec> { + pub fn circuit_set(&self) -> Vec> { let mut arr = Vec::new(); $( arr.push(self.[< b $i >].circuit_data().verifier_only.circuit_digest); @@ -171,7 +171,7 @@ macro_rules! impl_branch_circuits { } /// generates a proof from the inputs stored in `branch`. Depending on the size of the node, /// and the number of children proofs, it selects the right specialized circuit to generate the proof. - fn generate_proof( + pub fn generate_proof( &self, set: &RecursiveCircuits, branch_node: InputNode,