Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ byteorder = "1.4.3"
thiserror = "1.0"
group = "0.13.0"
once_cell = "1.18.0"
itertools = "0.12.0"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
pasta-msm = { version = "0.1.4" }
Expand Down
103 changes: 103 additions & 0 deletions src/spartan/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/// Macros to give syntactic sugar for zipWith pattern and variants.
///
/// ```ignore
/// use crate::spartan::zip_with;
/// use itertools::Itertools as _; // we use zip_eq to zip!
/// let v = vec![0, 1, 2];
/// let w = vec![2, 3, 4];
/// let y = vec![4, 5, 6];
///
/// // Using the `zip_with!` macro to zip three iterators together and apply a closure
/// // that sums the elements of each iterator.
/// let res = zip_with!((v.iter(), w.iter(), y.iter()), |a, b, c| a + b + c)
/// .collect::<Vec<_>>();
///
/// println!("{:?}", res); // Output: [6, 9, 12]
/// ```

#[macro_export]
macro_rules! zip_with {
// no iterator projection specified: the macro assumes the arguments *are* iterators
// ```ignore
// zip_with!((iter1, iter2, iter3), |a, b, c| a + b + c) ->
// iter1.zip_eq(iter2.zip_eq(iter3)).map(|(a, (b, c))| a + b + c)
// ```
//
// iterator projection specified: use it on each argument
// ```ignore
// zip_with!(par_iter, (vec1, vec2, vec3), |a, b, c| a + b + c) ->
// vec1.par_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).map(|(a, (b, c))| a + b + c)
// ````
($($f:ident,)? ($e:expr $(, $rest:expr)*), $($move:ident)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{
$crate::zip_with!($($f,)? ($e $(, $rest)*), map, $($move)? |$($i),+| $($work)*)
}};
// no iterator projection specified: the macro assumes the arguments *are* iterators
// optional zipping function specified as well: use it instead of map
// ```ignore
// zip_with!((iter1, iter2, iter3), for_each, |a, b, c| a + b + c) ->
// iter1.zip_eq(iter2.zip_eq(iter3)).for_each(|(a, (b, c))| a + b + c)
// ```
//
//
// iterator projection specified: use it on each argument
// optional zipping function specified as well: use it instead of map
// ```ignore
// zip_with!(par_iter, (vec1, vec2, vec3), for_each, |a, b, c| a + b + c) ->
// vec1.part_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).for_each(|(a, (b, c))| a + b + c)
// ```
($($f:ident,)? ($e:expr $(, $rest:expr)*), $worker:ident, $($move:ident,)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{
$crate::zip_all!($($f,)? ($e $(, $rest)*))
.$worker($($move)? |$crate::nested_idents!($($i),+)| {
$($work)*
})
}};
}

/// Like `zip_with` but use `for_each` instead of `map`.
#[macro_export]
macro_rules! zip_with_for_each {
// no iterator projection specified: the macro assumes the arguments *are* iterators
// ```ignore
// zip_with_for_each!((iter1, iter2, iter3), |a, b, c| a + b + c) ->
// iter1.zip_eq(iter2.zip_eq(iter3)).for_each(|(a, (b, c))| a + b + c)
// ```
//
// iterator projection specified: use it on each argument
// ```ignore
// zip_with_for_each!(par_iter, (vec1, vec2, vec3), |a, b, c| a + b + c) ->
// vec1.par_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).for_each(|(a, (b, c))| a + b + c)
// ````
($($f:ident,)? ($e:expr $(, $rest:expr)*), $($move:ident)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{
$crate::zip_with!($($f,)? ($e $(, $rest)*), for_each, $($move)? |$($i),+| $($work)*)
}};
}

// Foldright-like nesting for idents (a, b, c) -> (a, (b, c))
#[doc(hidden)]
#[macro_export]
macro_rules! nested_idents {
($a:ident, $b:ident) => {
($a, $b)
};
($first:ident, $($rest:ident),+) => {
($first, $crate::nested_idents!($($rest),+))
};
}

// Fold-right like zipping, with an optional function `f` to apply to each argument
#[doc(hidden)]
#[macro_export]
macro_rules! zip_all {
(($e:expr,)) => {
$e
};
($f:ident, ($e:expr,)) => {
$e.$f()
};
($f:ident, ($first:expr, $second:expr $(, $rest:expr)*)) => {
($first.$f().zip_eq($crate::zip_all!($f, ($second, $( $rest),*))))
};
(($first:expr, $second:expr $(, $rest:expr)*)) => {
($first.zip_eq($crate::zip_all!(($second, $( $rest),*))))
};
}
39 changes: 16 additions & 23 deletions src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
//! We also provide direct.rs that allows proving a step circuit directly with either of the two SNARKs.
//!
//! In polynomial.rs we also provide foundational types and functions for manipulating multilinear polynomials.
#[macro_use]
mod macros;
pub mod direct;
pub(crate) mod math;
pub mod polys;
Expand All @@ -14,6 +16,7 @@ mod sumcheck;

use crate::{traits::Engine, Commitment};
use ff::Field;
use itertools::Itertools as _;
use polys::multilinear::SparsePolynomial;
use rayon::{iter::IntoParallelRefIterator, prelude::*};

Expand Down Expand Up @@ -64,20 +67,17 @@ impl<E: Engine> PolyEvalWitness<E> {

let powers_of_s = powers::<E>(s, p_vec.len());

let p = p_vec
.par_iter()
.zip(powers_of_s.par_iter())
.map(|(v, &weight)| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
acc.into_iter().zip(v).map(|(x, y)| x + y).collect()
},
);
let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
zip_with!((acc.into_iter(), v), |x, y| x + y).collect()
},
);

PolyEvalWitness { p }
}
Expand Down Expand Up @@ -113,15 +113,8 @@ impl<E: Engine> PolyEvalInstance<E> {
s: &E::Scalar,
) -> PolyEvalInstance<E> {
let powers_of_s = powers::<E>(s, c_vec.len());
let e = e_vec
.par_iter()
.zip(powers_of_s.par_iter())
.map(|(e, p)| *e * p)
.sum();
let c = c_vec
.par_iter()
.zip(powers_of_s.par_iter())
.map(|(c, p)| *c * *p)
let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum();
let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p)
.reduce(Commitment::<E>::default, |acc, item| acc + item);

PolyEvalInstance {
Expand Down
11 changes: 4 additions & 7 deletions src/spartan/polys/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,10 @@ impl<Scalar: PrimeField> EqPolynomial<Scalar> {
let (evals_left, evals_right) = evals.split_at_mut(size);
let (evals_right, _) = evals_right.split_at_mut(size);

evals_left
.par_iter_mut()
.zip(evals_right.par_iter_mut())
.for_each(|(x, y)| {
*y = *x * r;
*x -= &*y;
});
zip_with_for_each!(par_iter_mut, (evals_left, evals_right), |x, y| {
*y = *x * r;
*x -= &*y;
});

size *= 2;
}
Expand Down
29 changes: 11 additions & 18 deletions src/spartan/polys/multilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use std::ops::{Add, Index};

use ff::PrimeField;
use itertools::Itertools as _;
use rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
Expand Down Expand Up @@ -65,12 +66,9 @@ impl<Scalar: PrimeField> MultilinearPolynomial<Scalar> {

let (left, right) = self.Z.split_at_mut(n);

left
.par_iter_mut()
.zip(right.par_iter())
.for_each(|(a, b)| {
*a += *r * (*b - *a);
});
zip_with_for_each!((left.par_iter_mut(), right.par_iter()), |a, b| {
*a += *r * (*b - *a);
});

self.Z.resize(n, Scalar::ZERO);
self.num_vars -= 1;
Expand All @@ -94,12 +92,12 @@ impl<Scalar: PrimeField> MultilinearPolynomial<Scalar> {

/// Evaluates the polynomial with the given evaluations and point.
pub fn evaluate_with(Z: &[Scalar], r: &[Scalar]) -> Scalar {
EqPolynomial::new(r.to_vec())
.evals()
.into_par_iter()
.zip(Z.into_par_iter())
.map(|(a, b)| a * b)
.sum()
zip_with!(
into_par_iter,
(EqPolynomial::new(r.to_vec()).evals(), Z),
|a, b| a * b
)
.sum()
}
}

Expand Down Expand Up @@ -167,12 +165,7 @@ impl<Scalar: PrimeField> Add for MultilinearPolynomial<Scalar> {
return Err("The two polynomials must have the same number of variables");
}

let sum: Vec<Scalar> = self
.Z
.iter()
.zip(other.Z.iter())
.map(|(a, b)| *a + *b)
.collect();
let sum: Vec<Scalar> = zip_with!(iter, (self.Z, other.Z), |a, b| *a + *b).collect();

Ok(MultilinearPolynomial::new(sum))
}
Expand Down
31 changes: 10 additions & 21 deletions src/spartan/ppsnark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ use crate::{
snark::{DigestHelperTrait, RelaxedR1CSSNARKTrait},
Engine, TranscriptEngineTrait, TranscriptReprTrait,
},
Commitment, CommitmentKey, CompressedCommitment,
zip_with, Commitment, CommitmentKey, CompressedCommitment,
};
use core::cmp::max;
use ff::Field;
use itertools::Itertools as _;
use once_cell::sync::OnceCell;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -339,13 +340,7 @@ impl<E: Engine> MemorySumcheckInstance<E> {
let inv = batch_invert(&T.par_iter().map(|e| *e + *r).collect::<Vec<E::Scalar>>())?;

// compute inv[i] * TS[i] in parallel
Ok(
inv
.par_iter()
.zip(TS.par_iter())
.map(|(e1, e2)| *e1 * *e2)
.collect::<Vec<_>>(),
)
Ok(zip_with!(par_iter, (inv, TS), |e1, e2| *e1 * e2).collect::<Vec<_>>())
},
|| batch_invert(&W.par_iter().map(|e| *e + *r).collect::<Vec<E::Scalar>>()),
)
Expand Down Expand Up @@ -853,11 +848,7 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARK<E, EE> {
let coeffs = powers::<E>(&s, claims.len());

// compute the joint claim
let claim = claims
.iter()
.zip(coeffs.iter())
.map(|(c_1, c_2)| *c_1 * c_2)
.sum();
let claim = zip_with!(iter, (claims, coeffs), |c_1, c_2| *c_1 * c_2).sum();

let mut e = claim;
let mut r: Vec<E::Scalar> = Vec::new();
Expand Down Expand Up @@ -1086,14 +1077,12 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARKTrait<E> for Relax
);

// a sum-check instance to prove the second claim
let val = pk
.S_repr
.val_A
.par_iter()
.zip(pk.S_repr.val_B.par_iter())
.zip(pk.S_repr.val_C.par_iter())
.map(|((v_a, v_b), v_c)| *v_a + c * *v_b + c * c * *v_c)
.collect::<Vec<E::Scalar>>();
let val = zip_with!(
par_iter,
(pk.S_repr.val_A, pk.S_repr.val_B, pk.S_repr.val_C),
|v_a, v_b, v_c| *v_a + c * *v_b + c * c * *v_c
)
.collect::<Vec<E::Scalar>>();
let inner_sc_inst = InnerSumcheckInstance {
claim: eval_Az_at_tau + c * eval_Bz_at_tau + c * c * eval_Cz_at_tau,
poly_L_row: MultilinearPolynomial::new(L_row.clone()),
Expand Down
Loading