Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: replace MPT metadata with counter for row_id computation (Part 3) #399

Open
wants to merge 10 commits into
base: generic-extraction-tree-creation
Choose a base branch
from
12 changes: 12 additions & 0 deletions mp2-common/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,15 @@ impl From<HashOut<F>> for HashOutput {
value.to_bytes().try_into().unwrap()
}
}

impl From<HashOutput> for HashOut<F> {
fn from(value: HashOutput) -> Self {
Self::from_bytes(&value.0)
}
}

impl From<&HashOutput> for HashOut<F> {
fn from(value: &HashOutput) -> Self {
Self::from_bytes(&value.0)
}
}
7 changes: 1 addition & 6 deletions mp2-v1/src/values_extraction/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use super::{
branch::{BranchCircuit, BranchWires},
extension::{ExtensionNodeCircuit, ExtensionNodeWires},
gadgets::{column_info::ColumnInfo, metadata_gadget::MetadataGadget},
gadgets::metadata_gadget::MetadataGadget,
leaf_mapping::{LeafMappingCircuit, LeafMappingWires},
leaf_mapping_of_mappings::{LeafMappingOfMappingsCircuit, LeafMappingOfMappingsWires},
leaf_single::{LeafSingleCircuit, LeafSingleWires},
Expand Down Expand Up @@ -887,7 +887,6 @@ mod tests {
>(table_info.clone());

let values_digest = compute_leaf_single_values_digest::<TEST_MAX_FIELD_PER_EVM>(
&metadata_digest,
table_info,
&extracted_column_identifiers,
value,
Expand All @@ -908,7 +907,6 @@ mod tests {
);

let values_digest = compute_leaf_mapping_values_digest::<TEST_MAX_FIELD_PER_EVM>(
&metadata_digest,
table_info,
&extracted_column_identifiers,
value,
Expand Down Expand Up @@ -936,7 +934,6 @@ mod tests {
>(table_info.clone());

let values_digest = compute_leaf_single_values_digest::<TEST_MAX_FIELD_PER_EVM>(
&metadata_digest,
table_info,
&extracted_column_identifiers,
value,
Expand All @@ -957,7 +954,6 @@ mod tests {
);

let values_digest = compute_leaf_mapping_values_digest::<TEST_MAX_FIELD_PER_EVM>(
&metadata_digest,
table_info,
&extracted_column_identifiers,
value,
Expand Down Expand Up @@ -993,7 +989,6 @@ mod tests {
let values_digest = compute_leaf_mapping_of_mappings_values_digest::<
TEST_MAX_FIELD_PER_EVM,
>(
&metadata_digest,
table_info,
&extracted_column_identifiers,
value,
Expand Down
42 changes: 31 additions & 11 deletions mp2-v1/src/values_extraction/gadgets/metadata_gadget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,14 @@ pub(crate) struct MetadataTarget<const MAX_COLUMNS: usize, const MAX_FIELD_PER_E
impl<const MAX_COLUMNS: usize, const MAX_FIELD_PER_EVM: usize>
MetadataTarget<MAX_COLUMNS, MAX_FIELD_PER_EVM>
{
/// Compute the metadata digest.
pub(crate) fn digest(&self, b: &mut CBuilder, slot: Target) -> CurveTarget {
/// Compute the metadata digest and number of actual columns.
pub(crate) fn digest_info(&self, b: &mut CBuilder, slot: Target) -> (CurveTarget, Target) {
let zero = b.zero();

let mut partial = b.curve_zero();
let mut non_extracted_column_found = b._false();
let mut num_extracted_columns = b.zero();
let mut num_extracted_columns = zero;
let mut num_actual_columns = zero;

for i in 0..MAX_COLUMNS {
let info = &self.table_info[i];
Expand All @@ -224,11 +227,12 @@ impl<const MAX_COLUMNS: usize, const MAX_FIELD_PER_EVM: usize>
// If the current column has to be extracted, we check that:
// - The EVM word associated to this column is the same as the EVM word we are extracting data from.
// - The slot associated to this column is the same as the slot we are extracting data from.
// - Ensure that we extract only from non-dummy columns.
// if is_extracted:
// evm_word == info.evm_word && slot == info.slot
// evm_word == info.evm_word && slot == info.slot && is_actual
let is_evm_word_eq = b.is_equal(self.evm_word, info.evm_word);
let is_slot_eq = b.is_equal(slot, info.slot);
let acc = [is_extracted, is_evm_word_eq, is_slot_eq]
let acc = [is_extracted, is_actual, is_evm_word_eq, is_slot_eq]
.into_iter()
.reduce(|acc, flag| b.and(acc, flag))
.unwrap();
Expand Down Expand Up @@ -265,6 +269,7 @@ impl<const MAX_COLUMNS: usize, const MAX_FIELD_PER_EVM: usize>
non_extracted_column_found = BoolTarget::new_unsafe(acc);
// num_extracted_columns += is_extracted
num_extracted_columns = b.add(num_extracted_columns, is_extracted.target);
num_actual_columns = b.add(num_actual_columns, is_actual.target);

// Compute the partial digest of all columns.
// mpt_metadata = H(info.slot || info.evm_word || info.byte_offset || info.bit_offset || info.length)
Expand Down Expand Up @@ -295,7 +300,7 @@ impl<const MAX_COLUMNS: usize, const MAX_FIELD_PER_EVM: usize>
less_than_or_equal_to_unsafe(b, num_extracted_columns, max_field_per_evm, 8);
b.assert_one(num_extracted_lt_or_eq_max.target);

partial
(partial, num_actual_columns)
}
}

Expand All @@ -311,32 +316,45 @@ pub(crate) mod tests {
struct TestMedataCircuit {
metadata_gadget: MetadataGadget<TEST_MAX_COLUMNS, TEST_MAX_FIELD_PER_EVM>,
slot: u8,
expected_num_actual_columns: usize,
expected_metadata_digest: Point,
}

impl UserCircuit<F, D> for TestMedataCircuit {
// Metadata target + slot + expected metadata digest
// Metadata target + slot + expected number of actual columns + expected metadata digest
type Wires = (
MetadataTarget<TEST_MAX_COLUMNS, TEST_MAX_FIELD_PER_EVM>,
Target,
Target,
CurveTarget,
);

fn build(b: &mut CBuilder) -> Self::Wires {
let metadata_target = MetadataGadget::build(b);
let slot = b.add_virtual_target();
let expected_num_actual_columns = b.add_virtual_target();
let expected_metadata_digest = b.add_virtual_curve_target();

let metadata_digest = metadata_target.digest(b, slot);
let (metadata_digest, num_actual_columns) = metadata_target.digest_info(b, slot);
b.connect_curve_points(metadata_digest, expected_metadata_digest);

(metadata_target, slot, expected_metadata_digest)
b.connect(num_actual_columns, expected_num_actual_columns);

(
metadata_target,
slot,
expected_num_actual_columns,
expected_metadata_digest,
)
}

fn prove(&self, pw: &mut PartialWitness<F>, wires: &Self::Wires) {
self.metadata_gadget.assign(pw, &wires.0);
pw.set_target(wires.1, F::from_canonical_u8(self.slot));
pw.set_curve_target(wires.2, self.expected_metadata_digest.to_weierstrass());
pw.set_target(
wires.2,
F::from_canonical_usize(self.expected_num_actual_columns),
);
pw.set_curve_target(wires.3, self.expected_metadata_digest.to_weierstrass());
}
}

Expand All @@ -348,11 +366,13 @@ pub(crate) mod tests {
let evm_word = rng.gen();

let metadata_gadget = MetadataGadget::sample(slot, evm_word);
let expected_num_actual_columns = metadata_gadget.num_actual_columns();
let expected_metadata_digest = metadata_gadget.digest();

let test_circuit = TestMedataCircuit {
metadata_gadget,
slot,
expected_num_actual_columns,
expected_metadata_digest,
};

Expand Down
13 changes: 7 additions & 6 deletions mp2-v1/src/values_extraction/leaf_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ where
{
pub fn build(b: &mut CBuilder) -> LeafMappingWires<NODE_LEN, MAX_COLUMNS, MAX_FIELD_PER_EVM> {
let zero = b.zero();
let one = b.one();

let key_id = b.add_virtual_target();
let metadata = MetadataGadget::build(b);
Expand All @@ -99,8 +100,10 @@ where
// Left pad the leaf value.
let value: Array<Target, MAPPING_LEAF_VALUE_LEN> = left_pad_leaf_value(b, &wires.value);

// Compute the metadata digest.
let metadata_digest = metadata.digest(b, slot.mapping_slot);
// Compute the metadata digest and number of actual columns.
let (metadata_digest, num_actual_columns) = metadata.digest_info(b, slot.mapping_slot);
// We add key column to number of actual columns.
let num_actual_columns = b.add(num_actual_columns, one);

// key_column_md = H( "\0KEY" || slot)
let key_id_prefix = b.constant(F::from_canonical_u32(u32::from_be_bytes(
Expand Down Expand Up @@ -139,11 +142,11 @@ where
// Compute the unique data to identify a row is the mapping key.
// row_unique_data = H(pack(left_pad32(key))
let row_unique_data = b.hash_n_to_hash_no_pad::<CHasher>(packed_mapping_key);
// row_id = H2int(row_unique_data || metadata_digest)
// row_id = H2int(row_unique_data || num_actual_columns)
let inputs = row_unique_data
.to_targets()
.into_iter()
.chain(metadata_digest.to_targets())
.chain(once(num_actual_columns))
.collect();
let hash = b.hash_n_to_hash_no_pad::<CHasher>(inputs);
let row_id = hash_to_int_target(b, hash);
Expand Down Expand Up @@ -222,7 +225,6 @@ where

#[cfg(test)]
mod tests {

use super::*;
use crate::{
tests::{TEST_MAX_COLUMNS, TEST_MAX_FIELD_PER_EVM},
Expand Down Expand Up @@ -309,7 +311,6 @@ mod tests {
>(table_info.clone(), slot, key_id);
// Compute the values digest.
let values_digest = compute_leaf_mapping_values_digest::<TEST_MAX_FIELD_PER_EVM>(
&metadata_digest,
table_info,
&extracted_column_identifiers,
value.clone().try_into().unwrap(),
Expand Down
12 changes: 7 additions & 5 deletions mp2-v1/src/values_extraction/leaf_mapping_of_mappings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ where
b: &mut CBuilder,
) -> LeafMappingOfMappingsWires<NODE_LEN, MAX_COLUMNS, MAX_FIELD_PER_EVM> {
let zero = b.zero();
let two = b.two();

let [outer_key_id, inner_key_id] = b.add_virtual_target_arr();
let metadata = MetadataGadget::build(b);
Expand All @@ -108,8 +109,10 @@ where
// Left pad the leaf value.
let value: Array<Target, MAPPING_LEAF_VALUE_LEN> = left_pad_leaf_value(b, &wires.value);

// Compute the metadata digest.
let metadata_digest = metadata.digest(b, slot.mapping_slot);
// Compute the metadata digest and number of actual columns.
let (metadata_digest, num_actual_columns) = metadata.digest_info(b, slot.mapping_slot);
// Add inner key and outer key columns to the number of actual columns.
let num_actual_columns = b.add(num_actual_columns, two);

// Compute the outer and inner key metadata digests.
let [outer_key_digest, inner_key_digest] = [
Expand Down Expand Up @@ -173,11 +176,11 @@ where
.chain(packed_inner_key)
.collect();
let row_unique_data = b.hash_n_to_hash_no_pad::<CHasher>(inputs);
// row_id = H2int(row_unique_data || metadata_digest)
// row_id = H2int(row_unique_data || num_actual_columns)
let inputs = row_unique_data
.to_targets()
.into_iter()
.chain(metadata_digest.to_targets())
.chain(once(num_actual_columns))
.collect();
let hash = b.hash_n_to_hash_no_pad::<CHasher>(inputs);
let row_id = hash_to_int_target(b, hash);
Expand Down Expand Up @@ -356,7 +359,6 @@ mod tests {
>(table_info.clone(), slot, outer_key_id, inner_key_id);
// Compute the values digest.
let values_digest = compute_leaf_mapping_of_mappings_values_digest::<TEST_MAX_FIELD_PER_EVM>(
&metadata_digest,
table_info,
&extracted_column_identifiers,
value.clone().try_into().unwrap(),
Expand Down
10 changes: 5 additions & 5 deletions mp2-v1/src/values_extraction/leaf_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use plonky2_ecdsa::gadgets::nonnative::CircuitBuilderNonNative;
use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5;
use recursion_framework::circuit_builder::CircuitLogicWires;
use serde::{Deserialize, Serialize};
use std::iter::once;

#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct LeafSingleWires<
Expand Down Expand Up @@ -83,8 +84,8 @@ where
// Left pad the leaf value.
let value: Array<Target, MAPPING_LEAF_VALUE_LEN> = left_pad_leaf_value(b, &wires.value);

// Compute the metadata digest.
let metadata_digest = metadata.digest(b, slot.base.slot);
// Compute the metadata digest and number of actual columns.
let (metadata_digest, num_actual_columns) = metadata.digest_info(b, slot.base.slot);

// Compute the values digest.
let values_digest = ColumnGadget::<MAX_FIELD_PER_EVM>::new(
Expand All @@ -94,12 +95,12 @@ where
)
.build(b);

// row_id = H2int(H("") || metadata_digest)
// row_id = H2int(H("") || num_actual_columns)
let empty_hash = b.constant_hash(*empty_poseidon_hash());
let inputs = empty_hash
.to_targets()
.into_iter()
.chain(metadata_digest.to_targets())
.chain(once(num_actual_columns))
.collect();
let hash = b.hash_n_to_hash_no_pad::<CHasher>(inputs);
let row_id = hash_to_int_target(b, hash);
Expand Down Expand Up @@ -253,7 +254,6 @@ mod tests {
let table_info = metadata.actual_table_info().to_vec();
let extracted_column_identifiers = metadata.extracted_column_identifiers();
let values_digest = compute_leaf_single_values_digest::<TEST_MAX_FIELD_PER_EVM>(
&metadata_digest,
table_info,
&extracted_column_identifiers,
value.clone().try_into().unwrap(),
Expand Down
Loading
Loading