Skip to content

Commit

Permalink
Compute rho_s crt once during folding (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthew-a-klein authored Dec 10, 2024
1 parent 251f7b8 commit 75ed3a2
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 29 deletions.
26 changes: 15 additions & 11 deletions latticefold/src/nifs/folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use ark_std::iter::successors;

use ark_std::iterable::Iterable;
use cyclotomic_rings::rings::SuitableRing;
use lattirust_ring::cyclotomic_ring::CRT;

use super::error::FoldingError;
use crate::ark_base::*;
Expand Down Expand Up @@ -114,13 +113,11 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> LFFoldingProver<N
Ok(eta_s)
}

fn compute_f_0(rho_s: &Vec<NTT::CoefficientRepresentation>, w_s: &[Witness<NTT>]) -> Vec<NTT> {
fn compute_f_0(rho_s: &[NTT], w_s: &[Witness<NTT>]) -> Vec<NTT> {
rho_s
.iter()
.zip(w_s)
.fold(vec![NTT::ZERO; w_s[0].f.len()], |acc, (&rho_i, w_i)| {
let rho_i: NTT = rho_i.crt();

acc.into_iter()
.zip(w_i.f.iter())
.map(|(acc_j, w_ij)| acc_j + rho_i * w_ij)
Expand Down Expand Up @@ -190,18 +187,19 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingProver<NTT
eta_s.iter().for_each(|etas| transcript.absorb_slice(etas));

// Step 5 get rho challenges
let rho_s = get_rhos::<_, _, P>(transcript);
let (rho_s_coeff, rho_s) = get_rhos::<_, _, P>(transcript);

let f_0: Vec<NTT> = Self::compute_f_0(&rho_s, &w_s);

// Step 6 compute v0, u0, y0, x_w0
let (v_0, cm_0, u_0, x_0) = compute_v0_u0_x0_cm_0(&rho_s, &theta_s, cm_i_s, &eta_s, ccs);
let (v_0, cm_0, u_0, x_0) =
compute_v0_u0_x0_cm_0(rho_s_coeff, rho_s, &theta_s, cm_i_s, &eta_s, ccs);

// Step 7: Compute f0 and Witness_0
let h = x_0.last().copied().ok_or(FoldingError::IncorrectLength)?;

let lcccs = prepare_public_output(r_0, v_0, cm_0, u_0, x_0, h);

let f_0: Vec<NTT> = Self::compute_f_0(&rho_s, &w_s);

let w_0 = Witness::from_f::<P>(f_0);

let folding_proof = FoldingProof {
Expand Down Expand Up @@ -364,11 +362,17 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingVerifier<N
.eta_s
.iter()
.for_each(|etas| transcript.absorb_slice(etas));
let rho_s = get_rhos::<_, _, P>(transcript);
let (rho_s_coeff, rho_s) = get_rhos::<_, _, P>(transcript);

// Step 6
let (v_0, cm_0, u_0, x_0) =
compute_v0_u0_x0_cm_0(&rho_s, &proof.theta_s, cm_i_s, &proof.eta_s, ccs);
let (v_0, cm_0, u_0, x_0) = compute_v0_u0_x0_cm_0(
rho_s_coeff,
rho_s,
&proof.theta_s,
cm_i_s,
&proof.eta_s,
ccs,
);

// Step 7: Compute f0 and Witness_0

Expand Down
15 changes: 7 additions & 8 deletions latticefold/src/nifs/folding/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use lattirust_ring::cyclotomic_ring::models::{
frog_ring::RqNTT as FrogRqNTT, goldilocks::RqNTT as GoldilocksRqNTT,
stark_prime::RqNTT as StarkRqNTT,
};
use lattirust_ring::cyclotomic_ring::{CRT, ICRT};
use lattirust_ring::cyclotomic_ring::ICRT;
use lattirust_ring::Ring;
use num_traits::{One, Zero};

Expand Down Expand Up @@ -441,7 +441,7 @@ fn test_get_rhos() {
let (_, _, mut transcript, _, _, _) = setup_test_environment::<RqNTT, CS, DP, C, W>(false);
let mut transcript_clone = transcript.clone();

let rho_s = get_rhos::<_, _, DP>(&mut transcript);
let (rho_s_coeff, rho_s) = get_rhos::<_, _, DP>(&mut transcript);

// Compute expected result
transcript_clone.absorb_field_element(&<_>::from_base_prime_field(
Expand All @@ -454,7 +454,7 @@ fn test_get_rhos() {
assert!(!rho_s.is_empty(), "Rhos vector should not be empty");
assert_eq!(rho_s.len(), 2 * DP::K, "Mismatch in Rhos length");
assert_eq!(
rho_s, expected_rhos,
rho_s_coeff, expected_rhos,
"Rhosvector does not match expected evaluations"
);
}
Expand Down Expand Up @@ -519,8 +519,9 @@ fn test_prepare_public_output() {
.for_each(|thetas| transcript.absorb_slice(thetas));
eta_s.iter().for_each(|etas| transcript.absorb_slice(etas));

let rho_s = get_rhos::<_, _, DP>(&mut transcript);
let (v_0, cm_0, u_0, x_0) = compute_v0_u0_x0_cm_0(&rho_s, &theta_s, &lccs, &eta_s, &ccs);
let (rho_s_coeff, rho_s) = get_rhos::<_, _, DP>(&mut transcript);
let (v_0, cm_0, u_0, x_0) =
compute_v0_u0_x0_cm_0(rho_s_coeff, rho_s, &theta_s, &lccs, &eta_s, &ccs);
let expected_x_0 = x_0[0..x_0.len() - 1].to_vec();
let h = x_0.last().copied().unwrap();

Expand Down Expand Up @@ -599,7 +600,7 @@ fn test_compute_f_0() {
.for_each(|thetas| transcript.absorb_slice(thetas));
eta_s.iter().for_each(|etas| transcript.absorb_slice(etas));

let rho_s = get_rhos::<_, _, DP>(&mut transcript);
let (_, rho_s) = get_rhos::<_, _, DP>(&mut transcript);

let f_0: Vec<RqNTT> =
LFFoldingProver::<RqNTT, PoseidonTranscript<RqNTT, CS>>::compute_f_0(&rho_s, &wit_s);
Expand All @@ -609,8 +610,6 @@ fn test_compute_f_0() {
.iter()
.zip(&wit_s)
.fold(vec![RqNTT::ZERO; wit_s[0].f.len()], |acc, (&rho_i, w_i)| {
let rho_i: RqNTT = rho_i.crt();

acc.into_iter()
.zip(w_i.f.iter())
.map(|(acc_j, w_ij)| acc_j + rho_i * w_ij)
Expand Down
21 changes: 11 additions & 10 deletions latticefold/src/nifs/folding/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ pub(super) fn get_rhos<
P: DecompositionParams,
>(
transcript: &mut T,
) -> Vec<R::CoefficientRepresentation> {
) -> (Vec<R::CoefficientRepresentation>, Vec<R>) {
transcript.absorb_field_element(&<R::BaseRing as Field>::from_base_prime_field(
<R::BaseRing as Field>::BasePrimeField::from_be_bytes_mod_order(b"rho_s"),
));

let mut rhos = transcript.get_small_challenges((2 * P::K) - 1); // Note that we are missing the first element
rhos.push(R::CoefficientRepresentation::ONE);

rhos
let mut rhos_coeff = transcript.get_small_challenges((2 * P::K) - 1); // Note that we are missing the first element
rhos_coeff.push(R::CoefficientRepresentation::ONE);
let rhos = CRT::elementwise_crt(rhos_coeff.clone());
(rhos_coeff, rhos)
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -277,18 +277,19 @@ pub(super) fn compute_sumcheck_claim_expected_value<NTT: Ring, P: DecompositionP
}

pub(super) fn compute_v0_u0_x0_cm_0<const C: usize, NTT: SuitableRing>(
rho_s: &[NTT::CoefficientRepresentation],
rho_s_coeff: Vec<NTT::CoefficientRepresentation>,
rho_s: Vec<NTT>,
theta_s: &[Vec<NTT>],
cm_i_s: &[LCCCS<C, NTT>],
eta_s: &[Vec<NTT>],
ccs: &CCS<NTT>,
) -> (Vec<NTT>, Commitment<C, NTT>, Vec<NTT>, Vec<NTT>) {
let v_0: Vec<NTT> = rot_lin_combination(rho_s, theta_s);
let v_0: Vec<NTT> = rot_lin_combination(&rho_s_coeff, theta_s);

let cm_0: Commitment<C, NTT> = rho_s
.iter()
.zip(cm_i_s.iter())
.map(|(&rho_i, cm_i)| cm_i.cm.clone() * rho_i.crt())
.map(|(&rho_i, cm_i)| cm_i.cm.clone() * rho_i)
.sum();

let u_0: Vec<NTT> = rho_s
Expand All @@ -297,7 +298,7 @@ pub(super) fn compute_v0_u0_x0_cm_0<const C: usize, NTT: SuitableRing>(
.map(|(&rho_i, etas_i)| {
etas_i
.iter()
.map(|etas_i_j| rho_i.crt() * etas_i_j)
.map(|etas_i_j| rho_i * etas_i_j)
.collect::<Vec<NTT>>()
})
.fold(vec![NTT::zero(); ccs.l], |mut acc, rho_i_times_etas_i| {
Expand All @@ -316,7 +317,7 @@ pub(super) fn compute_v0_u0_x0_cm_0<const C: usize, NTT: SuitableRing>(
.map(|(&rho_i, cm_i)| {
cm_i.x_w
.iter()
.map(|x_w_i| rho_i.crt() * x_w_i)
.map(|x_w_i| rho_i * x_w_i)
.collect::<Vec<NTT>>()
})
.fold(vec![NTT::zero(); ccs.n], |mut acc, rho_i_times_x_w_i| {
Expand Down

0 comments on commit 75ed3a2

Please sign in to comment.