Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement provably extract LIMIT/OFFSET circuits (Part1) #288

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
40 changes: 19 additions & 21 deletions verifiable-db/src/results_tree/extraction/public_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub enum ResultsExtractionPublicInputs {
}

#[derive(Clone, Debug)]
pub struct PublicInputs<'a, T, const S: usize> {
pub struct PublicInputs<'a, T> {
h: &'a [T],
min_val: &'a [T],
max_val: &'a [T],
Expand All @@ -64,7 +64,7 @@ pub struct PublicInputs<'a, T, const S: usize> {

const NUM_PUBLIC_INPUTS: usize = ResultsExtractionPublicInputs::Accumulator as usize + 1;

impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> {
impl<'a, T: Clone> PublicInputs<'a, T> {
const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [
Self::to_range(ResultsExtractionPublicInputs::TreeHash),
Self::to_range(ResultsExtractionPublicInputs::MinValue),
Expand Down Expand Up @@ -219,7 +219,7 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> {
}
}

impl<'a, const S: usize> PublicInputCommon for PublicInputs<'a, Target, S> {
impl<'a> PublicInputCommon for PublicInputs<'a, Target> {
const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES;

fn register_args(&self, cb: &mut CBuilder) {
Expand All @@ -236,7 +236,7 @@ impl<'a, const S: usize> PublicInputCommon for PublicInputs<'a, Target, S> {
}
}

impl<'a, const S: usize> PublicInputs<'a, Target, S> {
impl<'a> PublicInputs<'a, Target> {
pub fn tree_hash_target(&self) -> HashOutTarget {
HashOutTarget::try_from(self.to_tree_hash_raw()).unwrap()
}
Expand Down Expand Up @@ -278,7 +278,7 @@ impl<'a, const S: usize> PublicInputs<'a, Target, S> {
}
}

impl<'a, const S: usize> PublicInputs<'a, F, S> {
impl<'a> PublicInputs<'a, F> {
pub fn tree_hash(&self) -> HashOut<F> {
HashOut::try_from(self.to_tree_hash_raw()).unwrap()
}
Expand Down Expand Up @@ -336,7 +336,6 @@ mod tests {
plonk::circuit_builder::CircuitBuilder,
};

const S: usize = 10;
#[derive(Clone, Debug)]
struct TestPublicInputs<'a> {
pis: &'a [F],
Expand All @@ -346,8 +345,8 @@ mod tests {
type Wires = Vec<Target>;

fn build(c: &mut CircuitBuilder<F, D>) -> Self::Wires {
let targets = c.add_virtual_target_arr::<{ PublicInputs::<Target, S>::total_len() }>();
let pi_targets = PublicInputs::<Target, S>::from_slice(targets.as_slice());
let targets = c.add_virtual_target_arr::<{ PublicInputs::<Target>::total_len() }>();
let pi_targets = PublicInputs::<Target>::from_slice(targets.as_slice());
pi_targets.register_args(c);
pi_targets.to_vec()
}
Expand All @@ -359,54 +358,53 @@ mod tests {

#[test]
fn test_results_extraction_public_inputs() {
let pis_raw = random_vector::<u32>(PublicInputs::<F, S>::total_len()).to_fields();
let pis_raw = random_vector::<u32>(PublicInputs::<F>::total_len()).to_fields();

// use public inputs in circuit
let test_circuit = TestPublicInputs { pis: &pis_raw };
let proof = run_circuit::<F, D, C, _>(test_circuit);
assert_eq!(proof.public_inputs, pis_raw);

// check public inputs are constructed correctly
let pis = PublicInputs::<F, S>::from_slice(&proof.public_inputs);
let pis = PublicInputs::<F>::from_slice(&proof.public_inputs);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::TreeHash)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::TreeHash)],
pis.to_tree_hash_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::MinValue)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::MinValue)],
pis.to_min_value_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::MaxValue)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::MaxValue)],
pis.to_max_value_raw(),
);
assert_eq!(
&pis_raw
[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::PrimaryIndexValue)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::PrimaryIndexValue)],
pis.to_primary_index_value_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::IndexIds)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::IndexIds)],
pis.to_index_ids_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::MinCounter)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::MinCounter)],
&[*pis.to_min_counter_raw()],
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::MaxCounter)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::MaxCounter)],
&[*pis.to_max_counter_raw()],
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::OffsetRangeMin)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::OffsetRangeMin)],
&[*pis.to_offset_range_min_raw()],
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::OffsetRangeMax)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::OffsetRangeMax)],
&[*pis.to_offset_range_max_raw()],
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::Accumulator)],
&pis_raw[PublicInputs::<F>::to_range(ResultsExtractionPublicInputs::Accumulator)],
pis.to_accumulator_raw(),
);
}
Expand Down
118 changes: 84 additions & 34 deletions verifiable-db/src/results_tree/extraction/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ use serde::{Deserialize, Serialize};
use std::iter;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RecordWires<const MAX_NUM_RESULTS: usize> {
pub struct RecordWires {
#[serde(
serialize_with = "serialize_long_array",
deserialize_with = "deserialize_long_array"
)]
indexed_items: [UInt256Target; MAX_NUM_RESULTS],
indexed_items: [UInt256Target; 2],
#[serde(
serialize_with = "serialize_long_array",
deserialize_with = "deserialize_long_array"
Expand All @@ -44,14 +44,14 @@ pub struct RecordWires<const MAX_NUM_RESULTS: usize> {
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RecordCircuit<const MAX_NUM_RESULTS: usize> {
pub struct RecordCircuit {
/// Values of the indexed items for in this record;
/// if there is no secondary indexed item, just place the dummy value `0`
#[serde(
serialize_with = "serialize_long_array",
deserialize_with = "deserialize_long_array"
)]
pub(crate) indexed_items: [U256; MAX_NUM_RESULTS],
pub(crate) indexed_items: [U256; 2],
/// Integer identifiers of the indexed items
#[serde(
serialize_with = "serialize_long_array",
Expand All @@ -72,12 +72,38 @@ pub struct RecordCircuit<const MAX_NUM_RESULTS: usize> {
pub(crate) offset_range_max: F,
}

impl<const MAX_NUM_RESULTS: usize> RecordCircuit<MAX_NUM_RESULTS> {
pub fn build(b: &mut CBuilder) -> RecordWires<MAX_NUM_RESULTS> {
impl RecordCircuit {
pub fn new(
first_indexed_item: U256,
second_indexed_item: Option<U256>,
index_ids: [F; 2],
tree_hash: HashOut<F>,
counter: F,
is_stored_in_leaf: bool,
offset_range_min: F,
offset_range_max: F,
) -> Self {
let indexed_items = [
first_indexed_item,
second_indexed_item.unwrap_or(U256::ZERO),
];

Self {
indexed_items,
index_ids,
tree_hash,
counter,
is_stored_in_leaf,
offset_range_min,
offset_range_max,
}
}

pub fn build(b: &mut CBuilder) -> RecordWires {
let ffalse = b._false();
let empty_hash = b.constant_hash(*empty_poseidon_hash());

let indexed_items: [UInt256Target; MAX_NUM_RESULTS] = b.add_virtual_u256_arr_unsafe();
let indexed_items: [UInt256Target; 2] = b.add_virtual_u256_arr_unsafe();
let index_ids: [Target; 2] = b.add_virtual_target_arr();
let tree_hash = b.add_virtual_hash();
let counter = b.add_virtual_target();
Expand All @@ -104,7 +130,7 @@ impl<const MAX_NUM_RESULTS: usize> RecordCircuit<MAX_NUM_RESULTS> {
.chain(indexed_items[0].to_targets())
.chain(iter::once(index_ids[1]))
.chain(indexed_items[1].to_targets())
.chain(final_tree_hash.to_targets())
.chain(tree_hash.to_targets())
.collect();
let accumulator = b.map_to_curve_point(&accumulator_inputs);

Expand All @@ -118,7 +144,7 @@ impl<const MAX_NUM_RESULTS: usize> RecordCircuit<MAX_NUM_RESULTS> {
b.connect(is_out_of_range.target, ffalse.target);

// Register the public inputs.
PublicInputs::<_, MAX_NUM_RESULTS>::new(
PublicInputs::new(
&final_tree_hash.to_targets(),
&indexed_items[1].to_targets(),
&indexed_items[1].to_targets(),
Expand All @@ -143,7 +169,7 @@ impl<const MAX_NUM_RESULTS: usize> RecordCircuit<MAX_NUM_RESULTS> {
}
}

fn assign(&self, pw: &mut PartialWitness<F>, wires: &RecordWires<MAX_NUM_RESULTS>) {
fn assign(&self, pw: &mut PartialWitness<F>, wires: &RecordWires) {
wires
.indexed_items
.iter()
Expand All @@ -161,13 +187,11 @@ impl<const MAX_NUM_RESULTS: usize> RecordCircuit<MAX_NUM_RESULTS> {
/// Verified proof number = 0
pub(crate) const NUM_VERIFIED_PROOFS: usize = 0;

impl<const MAX_NUM_RESULTS: usize> CircuitLogicWires<F, D, NUM_VERIFIED_PROOFS>
for RecordWires<MAX_NUM_RESULTS>
{
impl CircuitLogicWires<F, D, NUM_VERIFIED_PROOFS> for RecordWires {
type CircuitBuilderParams = ();
type Inputs = RecordCircuit<MAX_NUM_RESULTS>;
type Inputs = RecordCircuit;

const NUM_PUBLIC_INPUTS: usize = PublicInputs::<F, MAX_NUM_RESULTS>::total_len();
const NUM_PUBLIC_INPUTS: usize = PublicInputs::<F>::total_len();

fn circuit_logic(
builder: &mut CBuilder,
Expand All @@ -186,7 +210,7 @@ impl<const MAX_NUM_RESULTS: usize> CircuitLogicWires<F, D, NUM_VERIFIED_PROOFS>
#[cfg(test)]
mod tests {
use super::*;
use mp2_common::{utils::ToFields, C};
use mp2_common::{group_hashing::map_to_curve_point, utils::ToFields, C};
use mp2_test::{
circuit::{run_circuit, UserCircuit},
utils::{gen_random_field_hash, gen_random_u256},
Expand All @@ -195,10 +219,8 @@ mod tests {
use rand::{thread_rng, Rng};
use std::array;

const MAX_NUM_RESULTS: usize = 20;

impl UserCircuit<F, D> for RecordCircuit<MAX_NUM_RESULTS> {
type Wires = RecordWires<MAX_NUM_RESULTS>;
impl UserCircuit<F, D> for RecordCircuit {
type Wires = RecordWires;

fn build(b: &mut CBuilder) -> Self::Wires {
RecordCircuit::build(b)
Expand All @@ -209,30 +231,36 @@ mod tests {
}
}

fn test_record_circuit(is_stored_in_leaf: bool) {
fn test_record_circuit(is_stored_in_leaf: bool, is_second_index_item_dummy: bool) {
// Construct the witness.
let mut rng = thread_rng();
let indexed_items = array::from_fn(|_| gen_random_u256(&mut rng));
let first_indexed_item = gen_random_u256(&mut rng);
let second_indexed_item = if is_second_index_item_dummy {
None
} else {
Some(gen_random_u256(&mut rng))
};
let index_ids = array::from_fn(|_| F::from_canonical_usize(rng.gen()));
let tree_hash = gen_random_field_hash();
let counter = F::from_canonical_u32(rng.gen());
let offset_range_min = counter - F::ONE;
let offset_range_max = counter + F::ONE;

// Construct the circuit.
let test_circuit = RecordCircuit {
indexed_items,
let test_circuit = RecordCircuit::new(
first_indexed_item,
second_indexed_item,
index_ids,
tree_hash,
counter,
is_stored_in_leaf,
offset_range_min,
offset_range_max,
};
);

// Proof for the test circuit.
let proof = run_circuit::<F, D, C, _>(test_circuit);
let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs);
let pi = PublicInputs::from_slice(&proof.public_inputs);

// Check the public inputs.

Expand All @@ -244,10 +272,10 @@ mod tests {
.clone()
.into_iter()
.chain(empty_hash_fields)
.chain(indexed_items[1].to_fields())
.chain(indexed_items[1].to_fields())
.chain(second_indexed_item.unwrap_or(U256::ZERO).to_fields())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: we could replace this with test_circuit.second_indexed_item so that we don't have to do this unwrap_or everywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 716e652.

.chain(second_indexed_item.unwrap_or(U256::ZERO).to_fields())
.chain(iter::once(index_ids[1]))
.chain(indexed_items[1].to_fields())
.chain(second_indexed_item.unwrap_or(U256::ZERO).to_fields())
.chain(tree_hash.to_fields())
.collect();
let exp_hash = H::hash_no_pad(&hash_inputs);
Expand All @@ -257,13 +285,13 @@ mod tests {
};

// Min value
assert_eq!(pi.min_value(), indexed_items[1]);
assert_eq!(pi.min_value(), second_indexed_item.unwrap_or(U256::ZERO));

// Max value
assert_eq!(pi.max_value(), indexed_items[1]);
assert_eq!(pi.max_value(), second_indexed_item.unwrap_or(U256::ZERO));

// Primary index value
assert_eq!(pi.primary_index_value(), indexed_items[0]);
assert_eq!(pi.primary_index_value(), first_indexed_item);

// Index ids
assert_eq!(pi.index_ids(), index_ids);
Expand All @@ -279,15 +307,37 @@ mod tests {

// Offset range max
assert_eq!(pi.offset_range_max(), offset_range_max);

// Accumulator
{
let accumulator_inputs: Vec<_> = iter::once(index_ids[0])
.chain(first_indexed_item.to_fields())
.chain(iter::once(index_ids[1]))
.chain(second_indexed_item.unwrap_or(U256::ZERO).to_fields())
.chain(tree_hash.to_fields())
.collect();
let exp_accumulator = map_to_curve_point(&accumulator_inputs);
assert_eq!(pi.accumulator(), exp_accumulator.to_weierstrass());
}
}

#[test]
fn test_record_circuit_storing_in_leaf() {
test_record_circuit(true);
test_record_circuit(true, false);
}

#[test]
fn test_record_circuit_storing_in_inter() {
test_record_circuit(false);
test_record_circuit(false, false);
}

#[test]
fn test_record_circuit_storing_in_leaf_with_dummy_item() {
test_record_circuit(true, true);
}

#[test]
fn test_record_circuit_storing_in_inter_with_dummy_item() {
test_record_circuit(false, true);
}
}
Loading