diff --git a/kernels/volk/volk_32fc_s32f_atan2_32f.h b/kernels/volk/volk_32fc_s32f_atan2_32f.h index cb8281a5..759db24c 100644 --- a/kernels/volk/volk_32fc_s32f_atan2_32f.h +++ b/kernels/volk/volk_32fc_s32f_atan2_32f.h @@ -117,6 +117,7 @@ static inline void volk_32fc_s32f_atan2_32f_a_avx2_fma(float* outputVector, const __m256 pi_2 = _mm256_set1_ps(0x1.921fb6p0f); const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)); const __m256 sign_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000)); + const __m256 zero = _mm256_setzero_ps(); unsigned int number = 0; unsigned int eighth_points = num_points / 8; @@ -133,6 +134,8 @@ static inline void volk_32fc_s32f_atan2_32f_a_avx2_fma(float* outputVector, _mm256_and_ps(y, abs_mask), _mm256_and_ps(x, abs_mask), _CMP_GT_OS); __m256 input = _mm256_div_ps(_mm256_blendv_ps(y, x, swap_mask), _mm256_blendv_ps(x, y, swap_mask)); + __m256 nan_mask = _mm256_cmp_ps(input, input, _CMP_UNORD_Q); + input = _mm256_blendv_ps(input, zero, nan_mask); __m256 result = _m256_arctan_poly_avx2_fma(input); input = @@ -174,6 +177,7 @@ static inline void volk_32fc_s32f_atan2_32f_a_avx2(float* outputVector, const __m256 pi_2 = _mm256_set1_ps(0x1.921fb6p0f); const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)); const __m256 sign_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000)); + const __m256 zero = _mm256_setzero_ps(); unsigned int number = 0; unsigned int eighth_points = num_points / 8; @@ -190,6 +194,8 @@ static inline void volk_32fc_s32f_atan2_32f_a_avx2(float* outputVector, _mm256_and_ps(y, abs_mask), _mm256_and_ps(x, abs_mask), _CMP_GT_OS); __m256 input = _mm256_div_ps(_mm256_blendv_ps(y, x, swap_mask), _mm256_blendv_ps(x, y, swap_mask)); + __m256 nan_mask = _mm256_cmp_ps(input, input, _CMP_UNORD_Q); + input = _mm256_blendv_ps(input, zero, nan_mask); __m256 result = _m256_arctan_poly_avx(input); input = @@ -235,6 +241,7 @@ static inline void volk_32fc_s32f_atan2_32f_u_avx2_fma(float* outputVector, const __m256 pi_2 = _mm256_set1_ps(0x1.921fb6p0f); const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)); const __m256 sign_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000)); + const __m256 zero = _mm256_setzero_ps(); unsigned int number = 0; unsigned int eighth_points = num_points / 8; @@ -251,6 +258,8 @@ static inline void volk_32fc_s32f_atan2_32f_u_avx2_fma(float* outputVector, _mm256_and_ps(y, abs_mask), _mm256_and_ps(x, abs_mask), _CMP_GT_OS); __m256 input = _mm256_div_ps(_mm256_blendv_ps(y, x, swap_mask), _mm256_blendv_ps(x, y, swap_mask)); + __m256 nan_mask = _mm256_cmp_ps(input, input, _CMP_UNORD_Q); + input = _mm256_blendv_ps(input, zero, nan_mask); __m256 result = _m256_arctan_poly_avx2_fma(input); input = @@ -292,6 +301,7 @@ static inline void volk_32fc_s32f_atan2_32f_u_avx2(float* outputVector, const __m256 pi_2 = _mm256_set1_ps(0x1.921fb6p0f); const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)); const __m256 sign_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000)); + const __m256 zero = _mm256_setzero_ps(); unsigned int number = 0; unsigned int eighth_points = num_points / 8; @@ -308,6 +318,8 @@ static inline void volk_32fc_s32f_atan2_32f_u_avx2(float* outputVector, _mm256_and_ps(y, abs_mask), _mm256_and_ps(x, abs_mask), _CMP_GT_OS); __m256 input = _mm256_div_ps(_mm256_blendv_ps(y, x, swap_mask), _mm256_blendv_ps(x, y, swap_mask)); + __m256 nan_mask = _mm256_cmp_ps(input, input, _CMP_UNORD_Q); + input = _mm256_blendv_ps(input, zero, nan_mask); __m256 result = _m256_arctan_poly_avx(input); input =