Skip to content

Commit

Permalink
Implement Karatsuba multiplication for Uint and BoxedUint (#649)
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Whitehead <[email protected]>
  • Loading branch information
andrewwhitehead authored Aug 16, 2024
1 parent 2952c76 commit de72555
Show file tree
Hide file tree
Showing 5 changed files with 590 additions and 38 deletions.
5 changes: 5 additions & 0 deletions src/const_choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ impl ConstChoice {
Self(self.0 & other.0)
}

/// Combine `self` with `other` by XOR-ing their inner values.
#[inline]
pub(crate) const fn xor(&self, other: Self) -> Self {
Self(self.0 ^ other.0)
}

/// Return `b` if `self` is truthy, otherwise return `a`.
#[inline]
pub(crate) const fn select_word(&self, a: Word, b: Word) -> Word {
Expand Down
49 changes: 46 additions & 3 deletions src/uint/boxed/mul.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! [`BoxedUint`] multiplication operations.
use crate::{
uint::mul::{mul_limbs, square_limbs},
uint::mul::{
karatsuba::{karatsuba_mul_limbs, karatsuba_square_limbs, KARATSUBA_MIN_STARTING_LIMBS},
mul_limbs, square_limbs,
},
BoxedUint, CheckedMul, Limb, WideningMul, Wrapping, WrappingMul, Zero,
};
use core::ops::{Mul, MulAssign};
Expand All @@ -12,7 +15,18 @@ impl BoxedUint {
///
/// Returns a widened output with a limb count equal to the sums of the input limb counts.
///
/// When the shorter operand has at least `KARATSUBA_MIN_STARTING_LIMBS` limbs the product is
/// computed with `karatsuba_mul_limbs`; otherwise `mul_limbs` is used.
pub fn mul(&self, rhs: &Self) -> Self {
    let size = self.nlimbs() + rhs.nlimbs();
    let overlap = self.nlimbs().min(rhs.nlimbs());

    if overlap >= KARATSUBA_MIN_STARTING_LIMBS {
        // A single allocation holds the `size`-limb product followed by `overlap * 2` limbs
        // of scratch space; the scratch tail is cut off again before returning.
        let mut limbs = vec![Limb::ZERO; size + overlap * 2];
        let (out, scratch) = limbs.split_at_mut(size);
        karatsuba_mul_limbs(&self.limbs, &rhs.limbs, out, scratch);
        limbs.truncate(size);
        return limbs.into();
    }

    let mut limbs = vec![Limb::ZERO; size];
    mul_limbs(&self.limbs, &rhs.limbs, &mut limbs);
    limbs.into()
}
Expand All @@ -24,7 +38,17 @@ impl BoxedUint {

/// Multiply `self` by itself.
///
/// The result has twice as many limbs as `self`. Inputs with at least
/// `KARATSUBA_MIN_STARTING_LIMBS * 2` limbs are squared with `karatsuba_square_limbs`;
/// smaller inputs use `square_limbs`.
pub fn square(&self) -> Self {
    let size = self.nlimbs() * 2;

    if self.nlimbs() >= KARATSUBA_MIN_STARTING_LIMBS * 2 {
        // A single allocation holds the `size`-limb result followed by another `size` limbs
        // of scratch space; the scratch tail is cut off again before returning.
        let mut limbs = vec![Limb::ZERO; size * 2];
        let (out, scratch) = limbs.split_at_mut(size);
        karatsuba_square_limbs(&self.limbs, out, scratch);
        limbs.truncate(size);
        return limbs.into();
    }

    let mut limbs = vec![Limb::ZERO; size];
    square_limbs(&self.limbs, &mut limbs);
    limbs.into()
}
Expand Down Expand Up @@ -144,4 +168,23 @@ mod tests {
}
}
}

// Consistency checks for `BoxedUint` multiplication on random inputs drawn from a
// fixed-seed RNG, so any failure is reproducible.
#[cfg(feature = "rand_core")]
#[test]
fn mul_cmp() {
use crate::RandomBits;
use rand_core::SeedableRng;
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);

// Multiplying a value by itself must agree with the dedicated squaring routine.
for _ in 0..50 {
let a = BoxedUint::random_bits(&mut rng, 4096);
assert_eq!(a.mul(&a), a.square(), "a = {a}");
}

// Operands of different bit lengths: the product must not depend on operand order.
for _ in 0..50 {
let a = BoxedUint::random_bits(&mut rng, 4096);
let b = BoxedUint::random_bits(&mut rng, 5000);
assert_eq!(a.mul(&b), b.mul(&a), "a={a}, b={b}");
}
}
}
117 changes: 83 additions & 34 deletions src/uint/mul.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
//! [`Uint`] multiplication operations.
// TODO(tarcieri): use Karatsuba for better performance

use self::karatsuba::UintKaratsubaMul;
use crate::{
Checked, CheckedMul, Concat, ConcatMixed, Limb, Uint, WideningMul, Wrapping, WrappingMul, Zero,
};
use core::ops::{Mul, MulAssign};
use subtle::CtOption;

pub(crate) mod karatsuba;

/// Implement the core schoolbook multiplication algorithm.
///
/// This is implemented as a macro to abstract over `const fn` and boxed use cases, since the latter
Expand All @@ -26,18 +27,15 @@ macro_rules! impl_schoolbook_multiplication {
while i < $lhs.len() {
let mut j = 0;
let mut carry = Limb::ZERO;
let xi = $lhs[i];

while j < $rhs.len() {
let k = i + j;

if k >= $lhs.len() {
let (n, c) = $hi[k - $lhs.len()].mac($lhs[i], $rhs[j], carry);
$hi[k - $lhs.len()] = n;
carry = c;
($hi[k - $lhs.len()], carry) = $hi[k - $lhs.len()].mac(xi, $rhs[j], carry);
} else {
let (n, c) = $lo[k].mac($lhs[i], $rhs[j], carry);
$lo[k] = n;
carry = c;
($lo[k], carry) = $lo[k].mac(xi, $rhs[j], carry);
}

j += 1;
Expand Down Expand Up @@ -72,18 +70,15 @@ macro_rules! impl_schoolbook_squaring {
while i < $limbs.len() {
let mut j = 0;
let mut carry = Limb::ZERO;
let xi = $limbs[i];

while j < i {
let k = i + j;

if k >= $limbs.len() {
let (n, c) = $hi[k - $limbs.len()].mac($limbs[i], $limbs[j], carry);
$hi[k - $limbs.len()] = n;
carry = c;
($hi[k - $limbs.len()], carry) = $hi[k - $limbs.len()].mac(xi, $limbs[j], carry);
} else {
let (n, c) = $lo[k].mac($limbs[i], $limbs[j], carry);
$lo[k] = n;
carry = c;
($lo[k], carry) = $lo[k].mac(xi, $limbs[j], carry);
}

j += 1;
Expand Down Expand Up @@ -117,24 +112,17 @@ macro_rules! impl_schoolbook_squaring {
let mut carry = Limb::ZERO;
let mut i = 0;
while i < $limbs.len() {
let xi = $limbs[i];
if (i * 2) < $limbs.len() {
let (n, c) = $lo[i * 2].mac($limbs[i], $limbs[i], carry);
$lo[i * 2] = n;
carry = c;
($lo[i * 2], carry) = $lo[i * 2].mac(xi, xi, carry);
} else {
let (n, c) = $hi[i * 2 - $limbs.len()].mac($limbs[i], $limbs[i], carry);
$hi[i * 2 - $limbs.len()] = n;
carry = c;
($hi[i * 2 - $limbs.len()], carry) = $hi[i * 2 - $limbs.len()].mac(xi, xi, carry);
}

if (i * 2 + 1) < $limbs.len() {
let (n, c) = $lo[i * 2 + 1].overflowing_add(carry);
$lo[i * 2 + 1] = n;
carry = c;
($lo[i * 2 + 1], carry) = $lo[i * 2 + 1].overflowing_add(carry);
} else {
let (n, c) = $hi[i * 2 + 1 - $limbs.len()].overflowing_add(carry);
$hi[i * 2 + 1 - $limbs.len()] = n;
carry = c;
($hi[i * 2 + 1 - $limbs.len()], carry) = $hi[i * 2 + 1 - $limbs.len()].overflowing_add(carry);
}

i += 1;
Expand All @@ -161,10 +149,27 @@ impl<const LIMBS: usize> Uint<LIMBS> {
&self,
rhs: &Uint<RHS_LIMBS>,
) -> (Self, Uint<RHS_LIMBS>) {
let mut lo = Self::ZERO;
let mut hi = Uint::<RHS_LIMBS>::ZERO;
impl_schoolbook_multiplication!(&self.limbs, &rhs.limbs, lo.limbs, hi.limbs);
(lo, hi)
if LIMBS == RHS_LIMBS {
if LIMBS == 128 {
let (a, b) = UintKaratsubaMul::<128>::multiply(&self.limbs, &rhs.limbs);
// resize() should be a no-op, but the compiler can't infer that Uint<LIMBS> is Uint<128>
return (a.resize(), b.resize());
}
if LIMBS == 64 {
let (a, b) = UintKaratsubaMul::<64>::multiply(&self.limbs, &rhs.limbs);
return (a.resize(), b.resize());
}
if LIMBS == 32 {
let (a, b) = UintKaratsubaMul::<32>::multiply(&self.limbs, &rhs.limbs);
return (a.resize(), b.resize());
}
if LIMBS == 16 {
let (a, b) = UintKaratsubaMul::<16>::multiply(&self.limbs, &rhs.limbs);
return (a.resize(), b.resize());
}
}

uint_mul_limbs(&self.limbs, &rhs.limbs)
}

/// Perform wrapping multiplication, discarding overflow.
Expand All @@ -180,10 +185,17 @@ impl<const LIMBS: usize> Uint<LIMBS> {

/// Square self, returning a "wide" result in two parts as (lo, hi).
///
/// For `LIMBS == 128` and `LIMBS == 64` the work is delegated to `UintKaratsubaMul::square`;
/// every other size goes through `uint_square_limbs`.
pub const fn square_wide(&self) -> (Self, Self) {
    if LIMBS == 128 {
        let (a, b) = UintKaratsubaMul::<128>::square(&self.limbs);
        // resize() should be a no-op, but the compiler can't infer that Uint<LIMBS> is Uint<128>
        return (a.resize(), b.resize());
    }
    if LIMBS == 64 {
        let (a, b) = UintKaratsubaMul::<64>::square(&self.limbs);
        // Same reasoning as above: resize() only reconciles the const generic parameter.
        return (a.resize(), b.resize());
    }

    uint_square_limbs(&self.limbs)
}
}

Expand Down Expand Up @@ -295,6 +307,30 @@ impl<const LIMBS: usize> WrappingMul for Uint<LIMBS> {
}
}

/// Schoolbook multiplication of two limb slices, returned as a `(lo, hi)` pair of [`Uint`]s.
///
/// `lhs` is expected to contain `LIMBS` limbs and `rhs` to contain `RHS_LIMBS` limbs; this is
/// checked with a `debug_assert!`.
#[inline]
pub(crate) const fn uint_mul_limbs<const LIMBS: usize, const RHS_LIMBS: usize>(
    lhs: &[Limb],
    rhs: &[Limb],
) -> (Uint<LIMBS>, Uint<RHS_LIMBS>) {
    debug_assert!(lhs.len() == LIMBS && rhs.len() == RHS_LIMBS);
    let mut low = Uint::<LIMBS>::ZERO;
    let mut high = Uint::<RHS_LIMBS>::ZERO;
    impl_schoolbook_multiplication!(lhs, rhs, low.limbs, high.limbs);
    (low, high)
}

/// Helper method to perform schoolbook squaring.
///
/// Squares `limbs` via `impl_schoolbook_squaring!` and returns the result as a `(lo, hi)` pair.
#[inline]
pub(crate) const fn uint_square_limbs<const LIMBS: usize>(
limbs: &[Limb],
) -> (Uint<LIMBS>, Uint<LIMBS>) {
let mut lo = Uint::<LIMBS>::ZERO;
let mut hi = Uint::<LIMBS>::ZERO;
impl_schoolbook_squaring!(limbs, lo.limbs, hi.limbs);
(lo, hi)
}

/// Wrapper function used by `BoxedUint`
#[cfg(feature = "alloc")]
pub(crate) fn mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) {
Expand Down Expand Up @@ -402,4 +438,17 @@ mod tests {
assert_eq!(lo, U256::ONE);
assert_eq!(hi, U256::MAX.wrapping_sub(&U256::ONE));
}

// Checks on random `U4096` values (fixed-seed RNG, so failures are reproducible) that
// multiplying a value by itself via `split_mul` matches `square_wide`.
#[cfg(feature = "rand_core")]
#[test]
fn mul_cmp() {
use crate::{Random, U4096};
use rand_core::SeedableRng;
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);

for _ in 0..50 {
let a = U4096::random(&mut rng);
assert_eq!(a.split_mul(&a), a.square_wide(), "a = {a}");
}
}
}
Loading

0 comments on commit de72555

Please sign in to comment.