1
- use num_complex:: Complex ;
2
-
3
1
use super :: LINEAR_IMPL_THRESHOLD ;
4
2
use crate :: internal_prelude:: * ;
5
3
@@ -108,33 +106,7 @@ pub fn sum<T: ComplexField>(mut mat: MatRef<'_, T>) -> T {
108
106
109
107
if try_const ! { T :: SIMD_CAPABILITIES . is_simd( ) } {
110
108
if let Some ( mat) = mat. try_as_col_major ( ) {
111
- if try_const ! { T :: IS_NATIVE_C32 } {
112
- let mat: MatRef < ' _ , Complex < f32 > , usize , usize , ContiguousFwd > = unsafe { crate :: hacks:: coerce ( mat) } ;
113
- let mat = unsafe {
114
- MatRef :: < ' _ , f32 , usize , usize , ContiguousFwd > :: from_raw_parts (
115
- mat. as_ptr ( ) as * const f32 ,
116
- 2 * mat. nrows ( ) ,
117
- mat. ncols ( ) ,
118
- ContiguousFwd ,
119
- mat. col_stride ( ) . wrapping_mul ( 2 ) ,
120
- )
121
- } ;
122
- return unsafe { crate :: hacks:: coerce ( sum_simd_pairwise_cols :: < f32 > ( mat) ) } ;
123
- } else if try_const ! { T :: IS_NATIVE_C64 } {
124
- let mat: MatRef < ' _ , Complex < f64 > , usize , usize , ContiguousFwd > = unsafe { crate :: hacks:: coerce ( mat) } ;
125
- let mat = unsafe {
126
- MatRef :: < ' _ , f64 , usize , usize , ContiguousFwd > :: from_raw_parts (
127
- mat. as_ptr ( ) as * const f64 ,
128
- 2 * mat. nrows ( ) ,
129
- mat. ncols ( ) ,
130
- ContiguousFwd ,
131
- mat. col_stride ( ) . wrapping_mul ( 2 ) ,
132
- )
133
- } ;
134
- return unsafe { crate :: hacks:: coerce ( sum_simd_pairwise_cols :: < f64 > ( mat) ) } ;
135
- } else {
136
- return sum_simd_pairwise_cols ( mat) ;
137
- }
109
+ return sum_simd_pairwise_cols ( mat) ;
138
110
}
139
111
}
140
112
@@ -154,27 +126,53 @@ mod tests {
154
126
use crate :: { Col , Mat , assert, unzip, zip} ;
155
127
156
128
#[ test]
157
- fn test_sum ( ) {
129
+ fn test_sum_real ( ) {
158
130
let relative_err = |a : f64 , b : f64 | ( a - b) . abs ( ) / f64:: max ( a. abs ( ) , b. abs ( ) ) ;
159
131
160
- for ( m, n) in [ ( 9 , 10 ) , ( 1023 , 5 ) , ( 42 , 1 ) ] {
132
+ for ( m, n) in [ ( 9 , 10 ) , ( 1023 , 1024 ) , ( 42 , 1 ) ] {
161
133
for factor in [ 0.0 , 1.0 , 1e30 , 1e250 , 1e-30 , 1e-250 ] {
162
134
let mat = Mat :: from_fn ( m, n, |i, j| factor * ( ( i + j) as f64 ) ) ;
163
135
let mut target = 0.0 ;
164
- zip ! ( mat. as_ref ( ) ) . for_each ( |unzip ! ( x) | {
136
+ zip ! ( mat. rb ( ) ) . for_each ( |unzip ! ( x) | {
165
137
target += x;
166
138
} ) ;
167
139
168
140
if factor == 0.0 {
169
- assert ! ( sum( mat. as_ref ( ) ) == target) ;
141
+ assert ! ( sum( mat. rb ( ) ) == target) ;
170
142
} else {
171
- assert ! ( relative_err( sum( mat. as_ref ( ) ) , target) < 1e-14 ) ;
143
+ assert ! ( relative_err( sum( mat. rb ( ) ) , target) < 1e-13 ) ;
172
144
}
173
145
}
174
146
}
175
147
176
- let mat = Col :: from_fn ( 10000000 , |_| 0.3 ) ;
148
+ let col = Col :: from_fn ( 10000000 , |_| 0.3 ) ;
177
149
let target = 0.3 * 10000000.0f64 ;
178
- assert ! ( relative_err( sum( mat. as_ref( ) . as_mat( ) ) , target) < 1e-14 ) ;
150
+ assert ! ( relative_err( sum( col. as_mat( ) ) , target) < 1e-14 ) ;
151
+ }
152
+
153
+ #[ test]
154
+ fn test_sum_cplx ( ) {
155
+ let relative_err = |a : c64 , b : c64 | abs ( & ( a - b) ) / f64:: max ( abs ( & a) , abs ( & b) ) ;
156
+
157
+ for ( m, n) in [ ( 9 , 10 ) , ( 1023 , 5 ) , ( 42 , 1 ) ] {
158
+ for factor in [ 0.0 , 1.0 , 1e30 , 1e250 , 1e-30 , 1e-250 ] {
159
+ let mat = Mat :: from_fn ( m, n, |i, j| {
160
+ let i = i as isize ;
161
+ let j = j as isize ;
162
+
163
+ c64:: new ( factor * ( ( i + j) as f64 ) , factor * ( ( i - j) as f64 ) )
164
+ } ) ;
165
+ let mut target = c64:: ZERO ;
166
+ zip ! ( mat. rb( ) ) . for_each ( |unzip ! ( x) | {
167
+ target += x;
168
+ } ) ;
169
+
170
+ if factor == 0.0 {
171
+ assert ! ( sum( mat. rb( ) ) == target) ;
172
+ } else {
173
+ assert ! ( relative_err( sum( mat. rb( ) ) , target) < 1e-14 ) ;
174
+ }
175
+ }
176
+ }
179
177
}
180
178
}
0 commit comments