Skip to content

Commit 7bc1567

Browse files
fix sum and activate v3 feature in pulp
1 parent 56028f4 commit 7bc1567

File tree

2 files changed

+35
-37
lines changed

2 files changed

+35
-37
lines changed

faer/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ num-traits = { version = "0.2.19", default-features = false }
2828

2929
gemm = { version = "0.18.2", default-features = false }
3030
dyn-stack = { version = "0.13.0", default-features = false, features = ["core-error", "alloc"] }
31-
pulp = { version = "0.21.3", default-features = false }
31+
pulp = { version = "0.21.3", default-features = false, features = ["x86-v3"] }
3232
equator = { version = "0.4.2" }
3333

3434
faer-traits = { path = "../faer-traits", version = "0.21.0" }

faer/src/linalg/reductions/sum.rs

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use num_complex::Complex;
2-
31
use super::LINEAR_IMPL_THRESHOLD;
42
use crate::internal_prelude::*;
53

@@ -108,33 +106,7 @@ pub fn sum<T: ComplexField>(mut mat: MatRef<'_, T>) -> T {
108106

109107
if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
110108
if let Some(mat) = mat.try_as_col_major() {
111-
if try_const! { T::IS_NATIVE_C32 } {
112-
let mat: MatRef<'_, Complex<f32>, usize, usize, ContiguousFwd> = unsafe { crate::hacks::coerce(mat) };
113-
let mat = unsafe {
114-
MatRef::<'_, f32, usize, usize, ContiguousFwd>::from_raw_parts(
115-
mat.as_ptr() as *const f32,
116-
2 * mat.nrows(),
117-
mat.ncols(),
118-
ContiguousFwd,
119-
mat.col_stride().wrapping_mul(2),
120-
)
121-
};
122-
return unsafe { crate::hacks::coerce(sum_simd_pairwise_cols::<f32>(mat)) };
123-
} else if try_const! { T::IS_NATIVE_C64 } {
124-
let mat: MatRef<'_, Complex<f64>, usize, usize, ContiguousFwd> = unsafe { crate::hacks::coerce(mat) };
125-
let mat = unsafe {
126-
MatRef::<'_, f64, usize, usize, ContiguousFwd>::from_raw_parts(
127-
mat.as_ptr() as *const f64,
128-
2 * mat.nrows(),
129-
mat.ncols(),
130-
ContiguousFwd,
131-
mat.col_stride().wrapping_mul(2),
132-
)
133-
};
134-
return unsafe { crate::hacks::coerce(sum_simd_pairwise_cols::<f64>(mat)) };
135-
} else {
136-
return sum_simd_pairwise_cols(mat);
137-
}
109+
return sum_simd_pairwise_cols(mat);
138110
}
139111
}
140112

@@ -154,27 +126,53 @@ mod tests {
154126
use crate::{Col, Mat, assert, unzip, zip};
155127

156128
#[test]
157-
fn test_sum() {
129+
fn test_sum_real() {
158130
let relative_err = |a: f64, b: f64| (a - b).abs() / f64::max(a.abs(), b.abs());
159131

160-
for (m, n) in [(9, 10), (1023, 5), (42, 1)] {
132+
for (m, n) in [(9, 10), (1023, 1024), (42, 1)] {
161133
for factor in [0.0, 1.0, 1e30, 1e250, 1e-30, 1e-250] {
162134
let mat = Mat::from_fn(m, n, |i, j| factor * ((i + j) as f64));
163135
let mut target = 0.0;
164-
zip!(mat.as_ref()).for_each(|unzip!(x)| {
136+
zip!(mat.rb()).for_each(|unzip!(x)| {
165137
target += x;
166138
});
167139

168140
if factor == 0.0 {
169-
assert!(sum(mat.as_ref()) == target);
141+
assert!(sum(mat.rb()) == target);
170142
} else {
171-
assert!(relative_err(sum(mat.as_ref()), target) < 1e-14);
143+
assert!(relative_err(sum(mat.rb()), target) < 1e-13);
172144
}
173145
}
174146
}
175147

176-
let mat = Col::from_fn(10000000, |_| 0.3);
148+
let col = Col::from_fn(10000000, |_| 0.3);
177149
let target = 0.3 * 10000000.0f64;
178-
assert!(relative_err(sum(mat.as_ref().as_mat()), target) < 1e-14);
150+
assert!(relative_err(sum(col.as_mat()), target) < 1e-14);
151+
}
152+
153+
#[test]
154+
fn test_sum_cplx() {
155+
let relative_err = |a: c64, b: c64| abs(&(a - b)) / f64::max(abs(&a), abs(&b));
156+
157+
for (m, n) in [(9, 10), (1023, 5), (42, 1)] {
158+
for factor in [0.0, 1.0, 1e30, 1e250, 1e-30, 1e-250] {
159+
let mat = Mat::from_fn(m, n, |i, j| {
160+
let i = i as isize;
161+
let j = j as isize;
162+
163+
c64::new(factor * ((i + j) as f64), factor * ((i - j) as f64))
164+
});
165+
let mut target = c64::ZERO;
166+
zip!(mat.rb()).for_each(|unzip!(x)| {
167+
target += x;
168+
});
169+
170+
if factor == 0.0 {
171+
assert!(sum(mat.rb()) == target);
172+
} else {
173+
assert!(relative_err(sum(mat.rb()), target) < 1e-14);
174+
}
175+
}
176+
}
179177
}
180178
}

0 commit comments

Comments
 (0)