Skip to content

Commit

Permalink
feat: add public inputs for results extraction circuits
Browse files Browse the repository at this point in the history
  • Loading branch information
Insun35 committed Aug 7, 2024
1 parent 35d699b commit 25d5e4c
Show file tree
Hide file tree
Showing 3 changed files with 373 additions and 0 deletions.
1 change: 1 addition & 0 deletions verifiable-db/src/results_tree/extraction/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod public_inputs;
371 changes: 371 additions & 0 deletions verifiable-db/src/results_tree/extraction/public_inputs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,371 @@
//! The public inputs of a set of circuits to extract the actual results
//! to be returned from the results tree
use itertools::Itertools;
use mp2_common::{
public_inputs::{PublicInputCommon, PublicInputRange},
types::{CBuilder, CURVE_TARGET_LEN},
u256::{UInt256Target, NUM_LIMBS},
utils::FromTargets,
};
use plonky2::{
hash::hash_types::{HashOutTarget, NUM_HASH_OUT_ELTS},
iop::target::Target,
};
use plonky2_ecgfp5::gadgets::curve::CurveTarget;
use std::iter::once;

/// Public inputs of the circuits to extract results from the results tree
pub enum ResultsExtractionPublicInputs {
/// `H`: `hash` Hash of the subtree rooted in the current node
TreeHash,
/// `min`: `u256` Minimum value of the indexed items for the subtree rooted
/// in the current node; it will correspond to the secondary indexed item for nodes
/// of the rows trees, and to the primary indexed item for nodes of the index tree
MinValue,
/// `max`: `u256` Maximum value of the indexed item for the subtree rooted in the current node;
/// it will correspond to the secondary indexed item for nodes of the rows trees,
/// and to the primary indexed item for nodes on the index tree
MaxValue,
/// `I`: `u256` Value of the primary indexed item for the rows stored in the subtree
/// of rows tree in the current node
PrimaryIndexValue,
/// `index_ids`: `[2]F` Integer identifiers of the indexed items
IndexIds,
/// `min_counter`: `u256` Minimum counter across the records in the
/// subtree rooted in the current node
// TODO(Insun35): Should this be `F`?
MinCounter,
/// `max_counter`: `u256` Maximum counter across the records in the
/// subtree rooted in the current node
// TODO(Insun35): Should this be `F`?
MaxCounter,
/// `offset_range_min`: `u256` lower bound of the range `[offset, limit + offset]` derived from the query
OffsetRangeMin,
/// `offset_range_max`: `u256` upper bound of the range `[offset, limit + offset]` derived from the query
OffsetRangeMax,
/// `D`: `Digest` order-agnostic digested employed to accumulate the result to be returned
Accumulator,
}

#[derive(Clone, Debug)]
pub struct PublicInputs<'a, T, const S: usize> {
h: &'a [T],
min_val: &'a [T],
max_val: &'a [T],
pri_idx_val: &'a [T],
idx_ids: &'a [T],
min_cnt: &'a [T],
max_cnt: &'a [T],
offset_range_min: &'a [T],
offset_range_max: &'a [T],
acc: &'a [T],
}

const NUM_PUBLIC_INPUTS: usize = ResultsExtractionPublicInputs::Accumulator as usize + 1;

impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> {
const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [
Self::to_range(ResultsExtractionPublicInputs::TreeHash),
Self::to_range(ResultsExtractionPublicInputs::MinValue),
Self::to_range(ResultsExtractionPublicInputs::MaxValue),
Self::to_range(ResultsExtractionPublicInputs::PrimaryIndexValue),
Self::to_range(ResultsExtractionPublicInputs::IndexIds),
Self::to_range(ResultsExtractionPublicInputs::MinCounter),
Self::to_range(ResultsExtractionPublicInputs::MaxCounter),
Self::to_range(ResultsExtractionPublicInputs::OffsetRangeMin),
Self::to_range(ResultsExtractionPublicInputs::OffsetRangeMax),
Self::to_range(ResultsExtractionPublicInputs::Accumulator),
];

const SIZES: [usize; NUM_PUBLIC_INPUTS] = [
// Tree hash
NUM_HASH_OUT_ELTS,
// Minimum value
NUM_LIMBS,
// Maximum value
NUM_LIMBS,
// Primary index value
NUM_LIMBS,
// Indexed column IDs
2,
// Minimum counter
NUM_LIMBS,
// Maximum counter
NUM_LIMBS,
// Offset Range Min
NUM_LIMBS,
// Offset Range Max
NUM_LIMBS,
// accumulator
CURVE_TARGET_LEN,
];

pub(crate) const fn to_range(pi: ResultsExtractionPublicInputs) -> PublicInputRange {
let mut i = 0;
let mut offset = 0;
let pi_pos = pi as usize;
while i < pi_pos {
offset += Self::SIZES[i];
i += 1;
}
offset..offset + Self::SIZES[pi_pos]
}

pub(crate) const fn total_len() -> usize {
Self::to_range(ResultsExtractionPublicInputs::Accumulator).end
}

pub(crate) fn to_tree_hash_raw(&self) -> &[T] {
self.h
}

pub(crate) fn to_min_value_raw(&self) -> &[T] {
self.min_val
}

pub(crate) fn to_max_value_raw(&self) -> &[T] {
self.max_val
}

pub(crate) fn to_primary_index_value_raw(&self) -> &[T] {
self.pri_idx_val
}

pub(crate) fn to_index_ids_raw(&self) -> &[T] {
self.idx_ids
}

pub(crate) fn to_min_counter_raw(&self) -> &[T] {
self.min_cnt
}

pub(crate) fn to_max_counter_raw(&self) -> &[T] {
self.max_cnt
}

pub(crate) fn to_offset_range_min_raw(&self) -> &[T] {
self.offset_range_min
}

pub(crate) fn to_offset_range_max_raw(&self) -> &[T] {
self.offset_range_max
}

pub(crate) fn to_accumulator_raw(&self) -> &[T] {
self.acc
}

pub fn from_slice(input: &'a [T]) -> Self {
assert!(
input.len() >= Self::total_len(),
"Input slice too short to build results public inputs, must be at least {} elements",
Self::total_len(),
);
Self {
h: &input[Self::PI_RANGES[0].clone()],
min_val: &input[Self::PI_RANGES[1].clone()],
max_val: &input[Self::PI_RANGES[2].clone()],
pri_idx_val: &input[Self::PI_RANGES[3].clone()],
idx_ids: &input[Self::PI_RANGES[4].clone()],
min_cnt: &input[Self::PI_RANGES[5].clone()],
max_cnt: &input[Self::PI_RANGES[6].clone()],
offset_range_min: &input[Self::PI_RANGES[7].clone()],
offset_range_max: &input[Self::PI_RANGES[8].clone()],
acc: &input[Self::PI_RANGES[9].clone()],
}
}

pub fn new(
h: &'a [T],
min_val: &'a [T],
max_val: &'a [T],
pri_idx_val: &'a [T],
idx_ids: &'a [T],
min_cnt: &'a [T],
max_cnt: &'a [T],
offset_range_min: &'a [T],
offset_range_max: &'a [T],
acc: &'a [T],
) -> Self {
Self {
h,
min_val,
max_val,
pri_idx_val,
idx_ids,
min_cnt,
max_cnt,
offset_range_min,
offset_range_max,
acc,
}
}

pub fn to_vec(&self) -> Vec<T> {
self.h
.iter()
.chain(self.min_val.iter())
.chain(self.max_val.iter())
.chain(self.pri_idx_val.iter())
.chain(self.idx_ids.iter())
.chain(self.min_cnt.iter())
.chain(self.max_cnt.iter())
.chain(self.offset_range_min.iter())
.chain(self.offset_range_max.iter())
.chain(self.acc.iter())
.cloned()
.collect_vec()
}
}

impl<'a, const S: usize> PublicInputCommon for PublicInputs<'a, Target, S> {
const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES;

fn register_args(&self, cb: &mut CBuilder) {
cb.register_public_inputs(self.h);
cb.register_public_inputs(self.min_val);
cb.register_public_inputs(self.max_val);
cb.register_public_inputs(self.pri_idx_val);
cb.register_public_inputs(self.idx_ids);
cb.register_public_inputs(self.min_cnt);
cb.register_public_inputs(self.max_cnt);
cb.register_public_inputs(self.offset_range_min);
cb.register_public_inputs(self.offset_range_max);
cb.register_public_inputs(self.acc);
}
}

impl<'a, const S: usize> PublicInputs<'a, Target, S> {
pub fn tree_hash_target(&self) -> HashOutTarget {
HashOutTarget::try_from(self.to_tree_hash_raw()).unwrap()
}

pub fn min_value_target(&self) -> UInt256Target {
UInt256Target::from_targets(self.to_min_value_raw())
}

pub fn max_value_target(&self) -> UInt256Target {
UInt256Target::from_targets(self.to_max_value_raw())
}

pub fn primary_index_value_target(&self) -> UInt256Target {
UInt256Target::from_targets(self.to_primary_index_value_raw())
}

pub fn index_ids_target(&self) -> [Target; 2] {
self.to_index_ids_raw().try_into().unwrap()
}

pub fn min_counter_target(&self) -> UInt256Target {
UInt256Target::from_targets(self.to_min_counter_raw())
}

pub fn max_counter_target(&self) -> UInt256Target {
UInt256Target::from_targets(self.to_max_counter_raw())
}

pub fn offset_range_min_target(&self) -> UInt256Target {
UInt256Target::from_targets(self.to_offset_range_min_raw())
}

pub fn offset_range_max_target(&self) -> UInt256Target {
UInt256Target::from_targets(self.to_offset_range_max_raw())
}

pub fn accumulator_target(&self) -> CurveTarget {
CurveTarget::from_targets(self.to_accumulator_raw())
}
}

#[cfg(test)]
mod tests {
use super::*;
use mp2_common::{public_inputs::PublicInputCommon, utils::ToFields, C, D, F};
use mp2_test::{
circuit::{run_circuit, UserCircuit},
utils::random_vector,
};
use plonky2::{
iop::{
target::Target,
witness::{PartialWitness, WitnessWrite},
},
plonk::circuit_builder::CircuitBuilder,
};

const S: usize = 10;
#[derive(Clone, Debug)]
struct TestPublicInputs<'a> {
pis: &'a [F],
}

impl<'a> UserCircuit<F, D> for TestPublicInputs<'a> {
type Wires = Vec<Target>;

fn build(c: &mut CircuitBuilder<F, D>) -> Self::Wires {
let targets = c.add_virtual_target_arr::<{ PublicInputs::<Target, S>::total_len() }>();
let pi_targets = PublicInputs::<Target, S>::from_slice(targets.as_slice());
pi_targets.register_args(c);
pi_targets.to_vec()
}

fn prove(&self, pw: &mut PartialWitness<F>, wires: &Self::Wires) {
pw.set_target_arr(wires, self.pis)
}
}

#[test]
fn test_results_extraction_public_inputs() {
let pis_raw = random_vector::<u32>(PublicInputs::<F, S>::total_len()).to_fields();

// use public inputs in circuit
let test_circuit = TestPublicInputs { pis: &pis_raw };
let proof = run_circuit::<F, D, C, _>(test_circuit);
assert_eq!(proof.public_inputs, pis_raw);

// check public inputs are constructed correctly
let pis = PublicInputs::<F, S>::from_slice(&proof.public_inputs);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::TreeHash)],
pis.to_tree_hash_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::MinValue)],
pis.to_min_value_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::MaxValue)],
pis.to_max_value_raw(),
);
assert_eq!(
&pis_raw
[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::PrimaryIndexValue)],
pis.to_primary_index_value_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::IndexIds)],
pis.to_index_ids_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::MinCounter)],
pis.to_min_counter_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::MaxCounter)],
pis.to_max_counter_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::OffsetRangeMin)],
pis.to_offset_range_min_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::OffsetRangeMax)],
pis.to_offset_range_max_raw(),
);
assert_eq!(
&pis_raw[PublicInputs::<F, S>::to_range(ResultsExtractionPublicInputs::Accumulator)],
pis.to_accumulator_raw(),
);
}
}
1 change: 1 addition & 0 deletions verifiable-db/src/results_tree/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub(crate) mod construction;
pub(crate) mod extraction;

0 comments on commit 25d5e4c

Please sign in to comment.