Skip to content

Commit

Permalink
fix sum and activate v3 feature in pulp
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah-quinones committed Feb 5, 2025
1 parent 56028f4 commit 7bc1567
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 37 deletions.
2 changes: 1 addition & 1 deletion faer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ num-traits = { version = "0.2.19", default-features = false }

gemm = { version = "0.18.2", default-features = false }
dyn-stack = { version = "0.13.0", default-features = false, features = ["core-error", "alloc"] }
pulp = { version = "0.21.3", default-features = false }
pulp = { version = "0.21.3", default-features = false, features = ["x86-v3"] }
equator = { version = "0.4.2" }

faer-traits = { path = "../faer-traits", version = "0.21.0" }
Expand Down
70 changes: 34 additions & 36 deletions faer/src/linalg/reductions/sum.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use num_complex::Complex;

use super::LINEAR_IMPL_THRESHOLD;
use crate::internal_prelude::*;

Expand Down Expand Up @@ -108,33 +106,7 @@ pub fn sum<T: ComplexField>(mut mat: MatRef<'_, T>) -> T {

if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
if let Some(mat) = mat.try_as_col_major() {
if try_const! { T::IS_NATIVE_C32 } {
let mat: MatRef<'_, Complex<f32>, usize, usize, ContiguousFwd> = unsafe { crate::hacks::coerce(mat) };
let mat = unsafe {
MatRef::<'_, f32, usize, usize, ContiguousFwd>::from_raw_parts(
mat.as_ptr() as *const f32,
2 * mat.nrows(),
mat.ncols(),
ContiguousFwd,
mat.col_stride().wrapping_mul(2),
)
};
return unsafe { crate::hacks::coerce(sum_simd_pairwise_cols::<f32>(mat)) };
} else if try_const! { T::IS_NATIVE_C64 } {
let mat: MatRef<'_, Complex<f64>, usize, usize, ContiguousFwd> = unsafe { crate::hacks::coerce(mat) };
let mat = unsafe {
MatRef::<'_, f64, usize, usize, ContiguousFwd>::from_raw_parts(
mat.as_ptr() as *const f64,
2 * mat.nrows(),
mat.ncols(),
ContiguousFwd,
mat.col_stride().wrapping_mul(2),
)
};
return unsafe { crate::hacks::coerce(sum_simd_pairwise_cols::<f64>(mat)) };
} else {
return sum_simd_pairwise_cols(mat);
}
return sum_simd_pairwise_cols(mat);
}
}

Expand All @@ -154,27 +126,53 @@ mod tests {
use crate::{Col, Mat, assert, unzip, zip};

#[test]
fn test_sum() {
fn test_sum_real() {
let relative_err = |a: f64, b: f64| (a - b).abs() / f64::max(a.abs(), b.abs());

for (m, n) in [(9, 10), (1023, 5), (42, 1)] {
for (m, n) in [(9, 10), (1023, 1024), (42, 1)] {
for factor in [0.0, 1.0, 1e30, 1e250, 1e-30, 1e-250] {
let mat = Mat::from_fn(m, n, |i, j| factor * ((i + j) as f64));
let mut target = 0.0;
zip!(mat.as_ref()).for_each(|unzip!(x)| {
zip!(mat.rb()).for_each(|unzip!(x)| {
target += x;
});

if factor == 0.0 {
assert!(sum(mat.as_ref()) == target);
assert!(sum(mat.rb()) == target);
} else {
assert!(relative_err(sum(mat.as_ref()), target) < 1e-14);
assert!(relative_err(sum(mat.rb()), target) < 1e-13);
}
}
}

let mat = Col::from_fn(10000000, |_| 0.3);
let col = Col::from_fn(10000000, |_| 0.3);
let target = 0.3 * 10000000.0f64;
assert!(relative_err(sum(mat.as_ref().as_mat()), target) < 1e-14);
assert!(relative_err(sum(col.as_mat()), target) < 1e-14);
}

#[test]
fn test_sum_cplx() {
    // Relative distance between two complex values, scaled by the larger magnitude.
    let relative_err = |a: c64, b: c64| abs(&(a - b)) / f64::max(abs(&a), abs(&b));

    for (m, n) in [(9, 10), (1023, 5), (42, 1)] {
        // Scale factors span zero, unit, very large and very small magnitudes.
        for scale in [0.0, 1.0, 1e30, 1e250, 1e-30, 1e-250] {
            // Entry (i, j) is scale * ((i + j) + (i - j) * sqrt(-1)); signed indices let the
            // imaginary part go negative.
            let a = Mat::from_fn(m, n, |i, j| {
                let (i, j) = (i as isize, j as isize);
                let re = factor_times(scale, i + j);
                let im = factor_times(scale, i - j);
                c64::new(re, im)
            });

            // Reference value: plain element-by-element accumulation.
            let mut expected = c64::ZERO;
            zip!(a.rb()).for_each(|unzip!(x)| {
                expected += x;
            });

            if factor_is_zero(scale) {
                // An all-zero matrix must sum to exactly zero.
                assert!(sum(a.rb()) == expected);
            } else {
                assert!(relative_err(sum(a.rb()), expected) < 1e-14);
            }
        }
    }

    // Multiplies the scale factor by a signed integer offset converted to `f64`.
    fn factor_times(scale: f64, offset: isize) -> f64 {
        scale * (offset as f64)
    }

    // True only for the exact-zero scale factor.
    fn factor_is_zero(scale: f64) -> bool {
        scale == 0.0
    }
}
}

0 comments on commit 7bc1567

Please sign in to comment.