2424#include < stddef.h>
2525#include < stdint.h>
2626
27+ #include < cmath> // std::abs
28+
2729#include " hwy/aligned_allocator.h"
2830
2931// clang-format off
@@ -52,21 +54,23 @@ HWY_NOINLINE void SimpleMatVecAdd(const MatT* HWY_RESTRICT mat,
5254 ThreadPool& pool) {
5355 if (add) {
5456 pool.Run (0 , rows, [=](uint64_t r, size_t /* thread*/ ) {
55- T dot = ConvertScalarTo<T>( 0 ) ;
57+ double dot = 0.0 ;
5658 for (size_t c = 0 ; c < cols; c++) {
5759 // Accumulate in double rather than T; fp16 += does not compile on clang (Arm).
58- dot = ConvertScalarTo<T>(dot + mat[r * cols + c] * vec[c]);
60+ dot += ConvertScalarTo<double >(mat[r * cols + c]) *
61+ ConvertScalarTo<double >(vec[c]);
5962 }
60- out[r] = dot + add[r];
63+ out[r] = ConvertScalarTo<T>( dot + ConvertScalarTo< double >( add[r])) ;
6164 });
6265 } else {
6366 pool.Run (0 , rows, [=](uint64_t r, size_t /* thread*/ ) {
64- T dot = ConvertScalarTo<T>( 0 ) ;
67+ double dot = 0.0 ;
6568 for (size_t c = 0 ; c < cols; c++) {
6669 // Accumulate in double rather than T; fp16 += does not compile on clang (Arm).
67- dot = ConvertScalarTo<T>(dot + mat[r * cols + c] * vec[c]);
70+ dot += ConvertScalarTo<double >(mat[r * cols + c]) *
71+ ConvertScalarTo<double >(vec[c]);
6872 }
69- out[r] = dot;
73+ out[r] = ConvertScalarTo<T>( dot) ;
7074 });
7175 }
7276}
@@ -118,22 +122,33 @@ HWY_MAYBE_UNUSED HWY_NOINLINE void SimpleMatVecAdd(
118122 }
119123}
120124
// Inlining attribute applied to the GenerateMod::operator() overloads.
//
// Workaround for incorrect codegen on Arm with Clang, which results in values
// of `av` >= 1E10: on that combination the generators are marked
// HWY_NOINLINE; everywhere else they remain HWY_INLINE. The miscompile can
// also be prevented by calling `Print(du, indices)`.
#if HWY_ARCH_ARM && HWY_COMPILER_CLANG
#define GENERATE_INLINE HWY_NOINLINE
#else
#define GENERATE_INLINE HWY_INLINE
#endif
132+
121133struct GenerateMod {
122134 template <class D , HWY_IF_NOT_BF16_D(D), HWY_IF_LANES_GT_D(D, 1 )>
123- Vec<D> operator ()(D d, Vec<RebindToUnsigned<D>> indices) const {
135+ GENERATE_INLINE Vec<D> operator ()(D d,
136+ Vec<RebindToUnsigned<D>> indices) const {
124137 const RebindToUnsigned<D> du;
125138 return Reverse2 (d, ConvertTo (d, And (indices, Set (du, 0xF ))));
126139 }
127140
128141 template <class D , HWY_IF_NOT_BF16_D(D), HWY_IF_LANES_LE_D(D, 1 )>
129- Vec<D> operator ()(D d, Vec<RebindToUnsigned<D>> indices) const {
142+ GENERATE_INLINE Vec<D> operator ()(D d,
143+ Vec<RebindToUnsigned<D>> indices) const {
130144 const RebindToUnsigned<D> du;
131145 return ConvertTo (d, And (indices, Set (du, 0xF )));
132146 }
133147
134148 // Requires >= 4 bf16 lanes for float32 Reverse2.
135149 template <class D , HWY_IF_BF16_D(D), HWY_IF_LANES_GT_D(D, 2 )>
136- Vec<D> operator ()(D d, Vec<RebindToUnsigned<D>> indices) const {
150+ GENERATE_INLINE Vec<D> operator ()(D d,
151+ Vec<RebindToUnsigned<D>> indices) const {
137152 const RebindToUnsigned<D> du;
138153 const RebindToSigned<D> di;
139154 const RepartitionToWide<decltype (di)> dw;
@@ -146,9 +161,10 @@ struct GenerateMod {
146161
147162 // For one or two lanes, we don't have OrderedDemote2To nor Reverse2.
148163 template <class D , HWY_IF_BF16_D(D), HWY_IF_LANES_LE_D(D, 2 )>
149- Vec<D> operator ()(D d, Vec<RebindToUnsigned<D>> indices) const {
164+ GENERATE_INLINE Vec<D> operator ()(D d,
165+ Vec<RebindToUnsigned<D>> indices) const {
150166 const Rebind<float , D> df;
151- return DemoteTo (d, Set (df, GetLane (indices)));
167+ return DemoteTo (d, Set (df, static_cast < float >( GetLane (indices) )));
152168 }
153169};
154170
@@ -194,15 +210,19 @@ class TestMatVecAdd {
194210 for (size_t i = 0 ; i < kRows ; ++i) {
195211 const double exp = ConvertScalarTo<double >(expected[i]);
196212 const double act = ConvertScalarTo<double >(actual[i]);
197- const double tolerance =
198- exp * 20 * 1.0 /
199- (1ULL << HWY_MIN (MantissaBits<MatT>(), MantissaBits<VecT>()));
200- if (!(exp - tolerance <= act && act <= exp + tolerance)) {
213+ const double epsilon =
214+ 1.0 / (1ULL << HWY_MIN (MantissaBits<MatT>(), MantissaBits<VecT>()));
215+ const double tolerance = exp * 20.0 / epsilon;
216+ const double l1 = std::abs (exp - act);
217+ const double rel = exp == 0.0 ? 0.0 : l1 / exp;
218+
219+ if (l1 > tolerance && rel > epsilon) {
201220 fprintf (stderr,
202- " %s/%s %zu x %zu, %s: mismatch at %zu %f %f; tol %f\n " ,
221+ " %s/%s %zu x %zu, %s: mismatch at %zu: %E != %E; "
222+ " tol %f l1 %f rel %E\n " ,
203223 TypeName (MatT (), 1 ).c_str (), TypeName (VecT (), 1 ).c_str (),
204224 kRows , kCols , (with_add ? " with add" : " without add" ), i, exp,
205- act, tolerance);
225+ act, tolerance, l1, rel );
206226 HWY_ASSERT (0 );
207227 }
208228 }
0 commit comments