
Commit 575ee86

added avx2 kernels, added inf/0.f/neg masks for all kernels
Signed-off-by: Magnus Lundmark <[email protected]>
1 parent e220082 commit 575ee86

File tree

7 files changed: +264 −119 lines


include/volk/volk_avx2_intrinsics.h

Lines changed: 27 additions & 0 deletions
@@ -18,6 +18,33 @@
 #include "volk/volk_avx_intrinsics.h"
 #include <immintrin.h>
 
+/*
+ * Newton-Raphson refined reciprocal square root: 1/sqrt(a)
+ * AVX2 version with native 256-bit integer operations for edge case handling
+ * Handles edge cases: +0 → +Inf, +Inf → 0
+ */
+static inline __m256 _mm256_rsqrt_nr_avx2(const __m256 a)
+{
+    const __m256 HALF = _mm256_set1_ps(0.5f);
+    const __m256 THREE_HALFS = _mm256_set1_ps(1.5f);
+
+    const __m256 x0 = _mm256_rsqrt_ps(a); // +Inf for +0, 0 for +Inf
+
+    // Newton-Raphson: x1 = x0 * (1.5 - 0.5 * a * x0^2)
+    __m256 x1 = _mm256_mul_ps(
+        x0,
+        _mm256_sub_ps(THREE_HALFS,
+                      _mm256_mul_ps(HALF, _mm256_mul_ps(_mm256_mul_ps(x0, x0), a))));
+
+    // For +0 and +Inf inputs, x0 is correct but NR produces NaN due to Inf*0
+    // AVX2: native 256-bit integer compare
+    __m256i a_si = _mm256_castps_si256(a);
+    __m256i zero_mask = _mm256_cmpeq_epi32(a_si, _mm256_setzero_si256());
+    __m256i inf_mask = _mm256_cmpeq_epi32(a_si, _mm256_set1_epi32(0x7F800000));
+    __m256 special_mask = _mm256_castsi256_ps(_mm256_or_si256(zero_mask, inf_mask));
+    return _mm256_blendv_ps(x1, x0, special_mask);
+}
+
 static inline __m256 _mm256_real(const __m256 z1, const __m256 z2)
 {
     const __m256i permute_mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
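
Not part of the commit, but the new edge-case behavior is easy to exercise with a throwaway driver like the sketch below. The test values and file layout are illustrative; it assumes VOLK's include directory is on the path and a build with -mavx2.

#include "volk/volk_avx2_intrinsics.h"
#include <math.h>
#include <stdio.h>

int main(void)
{
    // Lane 0 holds +0 (expect +Inf), lane 1 holds +Inf (expect 0);
    // the remaining lanes should come back as refined 1/sqrt values.
    float in[8] = { 0.0f, INFINITY, 1.0f, 4.0f, 0.25f, 2.0f, 100.0f, 1e-6f };
    float out[8];
    _mm256_storeu_ps(out, _mm256_rsqrt_nr_avx2(_mm256_loadu_ps(in)));
    for (int i = 0; i < 8; i++)
        printf("rsqrt(%g) ~= %g\n", (double)in[i], (double)out[i]);
    return 0;
}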

include/volk/volk_avx512_intrinsics.h

Lines changed: 13 additions & 2 deletions
@@ -20,15 +20,26 @@
 // Newton-Raphson refined reciprocal square root: 1/sqrt(a)
 // One iteration doubles precision from ~12-bit to ~24-bit
 // x1 = x0 * (1.5 - 0.5 * a * x0^2)
+// Handles edge cases: +0 → +Inf, +Inf → 0
 // Requires AVX512F
 ////////////////////////////////////////////////////////////////////////
 static inline __m512 _mm512_rsqrt_nr_ps(const __m512 a)
 {
     const __m512 HALF = _mm512_set1_ps(0.5f);
     const __m512 THREE_HALFS = _mm512_set1_ps(1.5f);
-    const __m512 x0 = _mm512_rsqrt14_ps(a);
-    return _mm512_mul_ps(
+
+    const __m512 x0 = _mm512_rsqrt14_ps(a); // +Inf for +0, 0 for +Inf
+
+    // Newton-Raphson: x1 = x0 * (1.5 - 0.5 * a * x0^2)
+    __m512 x1 = _mm512_mul_ps(
         x0, _mm512_fnmadd_ps(HALF, _mm512_mul_ps(_mm512_mul_ps(x0, x0), a), THREE_HALFS));
+
+    // For +0 and +Inf inputs, x0 is correct but NR produces NaN due to Inf*0
+    // Blend: use x0 where a == +0 or a == +Inf, else use x1
+    __m512i a_si = _mm512_castps_si512(a);
+    __mmask16 zero_mask = _mm512_cmpeq_epi32_mask(a_si, _mm512_setzero_si512());
+    __mmask16 inf_mask = _mm512_cmpeq_epi32_mask(a_si, _mm512_set1_epi32(0x7F800000));
+    return _mm512_mask_blend_ps(zero_mask | inf_mask, x1, x0);
 }
 
 ////////////////////////////////////////////////////////////////////////
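
For reference: _mm512_fnmadd_ps(a, b, c) computes c - a*b in one fused operation, so the AVX-512 path expresses the same refinement step as the other kernels with one fewer rounding. A scalar sketch of the identity being used (illustrative only, not part of the commit):

/* One Newton-Raphson step for 1/sqrt(a), as folded into fnmadd above:
 * x1 = x0 * (THREE_HALFS - HALF * a * x0^2) */
static inline float nr_step_scalar(float a, float x0)
{
    return x0 * (1.5f - 0.5f * a * x0 * x0);
}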

include/volk/volk_avx_intrinsics.h

Lines changed: 23 additions & 2 deletions
@@ -21,16 +21,37 @@
  * Newton-Raphson refined reciprocal square root: 1/sqrt(a)
  * One iteration doubles precision from ~12-bit to ~24-bit
  * x1 = x0 * (1.5 - 0.5 * a * x0^2)
+ * Handles edge cases: +0 → +Inf, +Inf → 0
  */
 static inline __m256 _mm256_rsqrt_nr_ps(const __m256 a)
 {
     const __m256 HALF = _mm256_set1_ps(0.5f);
     const __m256 THREE_HALFS = _mm256_set1_ps(1.5f);
-    const __m256 x0 = _mm256_rsqrt_ps(a);
-    return _mm256_mul_ps(
+
+    const __m256 x0 = _mm256_rsqrt_ps(a); // +Inf for +0, 0 for +Inf
+
+    // Newton-Raphson: x1 = x0 * (1.5 - 0.5 * a * x0^2)
+    __m256 x1 = _mm256_mul_ps(
         x0,
         _mm256_sub_ps(THREE_HALFS,
                       _mm256_mul_ps(HALF, _mm256_mul_ps(_mm256_mul_ps(x0, x0), a))));
+
+    // For +0 and +Inf inputs, x0 is correct but NR produces NaN due to Inf*0
+    // Blend: use x0 where a == +0 or a == +Inf, else use x1
+    // AVX-only: use SSE2 integer compare, then reconstruct AVX mask
+    __m128i a_lo = _mm256_castsi256_si128(_mm256_castps_si256(a));
+    __m128i a_hi = _mm_castps_si128(_mm256_extractf128_ps(a, 1));
+    __m128i zero_si = _mm_setzero_si128();
+    __m128i inf_si = _mm_set1_epi32(0x7F800000);
+    __m128i zero_mask_lo = _mm_cmpeq_epi32(a_lo, zero_si);
+    __m128i zero_mask_hi = _mm_cmpeq_epi32(a_hi, zero_si);
+    __m128i inf_mask_lo = _mm_cmpeq_epi32(a_lo, inf_si);
+    __m128i inf_mask_hi = _mm_cmpeq_epi32(a_hi, inf_si);
+    __m128 mask_lo = _mm_castsi128_ps(_mm_or_si128(zero_mask_lo, inf_mask_lo));
+    __m128 mask_hi = _mm_castsi128_ps(_mm_or_si128(zero_mask_hi, inf_mask_hi));
+    __m256 special_mask =
+        _mm256_insertf128_ps(_mm256_castps128_ps256(mask_lo), mask_hi, 1);
+    return _mm256_blendv_ps(x1, x0, special_mask);
 }
 
 /*
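
The lane-split above is needed because AVX1 has no 256-bit integer compare. A hypothetical alternative would be a floating-point compare, which stays in one 256-bit register but is not bit-exact with the commit's integer compare, since FP equality also matches -0.0f. Sketch, assuming <immintrin.h> and <math.h>:

static inline __m256 special_mask_cmp_ps(const __m256 a)
{
    // Matches a == +0.0f, a == -0.0f, or a == +Inf (FP equality, not bitwise)
    __m256 zero_mask = _mm256_cmp_ps(a, _mm256_setzero_ps(), _CMP_EQ_OQ);
    __m256 inf_mask = _mm256_cmp_ps(a, _mm256_set1_ps(INFINITY), _CMP_EQ_OQ);
    return _mm256_or_ps(zero_mask, inf_mask);
}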

include/volk/volk_neon_intrinsics.h

Lines changed: 15 additions & 8 deletions
@@ -80,16 +80,23 @@ static inline float32x4_t _vmagnitudesquaredq_f32(float32x4x2_t cmplxValue)
     return result;
 }
 
-/* Inverse square root for float32x4_t */
+/* Inverse square root for float32x4_t
+ * Handles edge cases: +0 → +Inf, +Inf → 0 */
 static inline float32x4_t _vinvsqrtq_f32(float32x4_t x)
 {
-    float32x4_t sqrt_reciprocal = vrsqrteq_f32(x);
-    sqrt_reciprocal = vmulq_f32(
-        vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal);
-    sqrt_reciprocal = vmulq_f32(
-        vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal);
-
-    return sqrt_reciprocal;
+    float32x4_t x0 = vrsqrteq_f32(x); // +Inf for +0, 0 for +Inf
+
+    // Newton-Raphson refinement using vrsqrtsq_f32
+    float32x4_t x1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, x0), x0), x0);
+    x1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, x1), x1), x1);
+
+    // For +0 and +Inf inputs, x0 is correct but NR produces NaN due to Inf*0
+    // Blend: use x0 where x == +0 or x == +Inf, else use x1
+    uint32x4_t x_bits = vreinterpretq_u32_f32(x);
+    uint32x4_t zero_mask = vceqq_u32(x_bits, vdupq_n_u32(0x00000000));
+    uint32x4_t inf_mask = vceqq_u32(x_bits, vdupq_n_u32(0x7F800000));
+    uint32x4_t special_mask = vorrq_u32(zero_mask, inf_mask);
+    return vbslq_f32(special_mask, x0, x1);
 }
 
 /* Square root for ARMv7 NEON (no vsqrtq_f32)
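
A note on the NEON helper (not in the diff): vrsqrtsq_f32(p, q) returns (3 - p*q) / 2, so each refinement line above is one full Newton-Raphson step; the vrsqrteq_f32 estimate is coarser than the x86 rsqrt estimate, which is presumably why this kernel runs two steps. Scalar model of one step (illustrative):

/* x1 = x0 * vrsqrts(x * x0, x0) = x0 * (3 - x * x0^2) / 2,
 * algebraically the same step as x0 * (1.5 - 0.5 * x * x0^2) */
static inline float neon_nr_step(float x, float x0)
{
    return x0 * ((3.0f - (x * x0) * x0) * 0.5f);
}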

include/volk/volk_sse_intrinsics.h

Lines changed: 16 additions & 2 deletions
@@ -15,20 +15,34 @@
 
 #ifndef INCLUDE_VOLK_VOLK_SSE_INTRINSICS_H_
 #define INCLUDE_VOLK_VOLK_SSE_INTRINSICS_H_
+#include <emmintrin.h>
 #include <xmmintrin.h>
 
 /*
  * Newton-Raphson refined reciprocal square root: 1/sqrt(a)
  * One iteration doubles precision from ~12-bit to ~24-bit
  * x1 = x0 * (1.5 - 0.5 * a * x0^2)
+ * Handles edge cases: +0 → +Inf, +Inf → 0
  */
 static inline __m128 _mm_rsqrt_nr_ps(const __m128 a)
 {
     const __m128 HALF = _mm_set1_ps(0.5f);
     const __m128 THREE_HALFS = _mm_set1_ps(1.5f);
-    const __m128 x0 = _mm_rsqrt_ps(a);
-    return _mm_mul_ps(
+
+    const __m128 x0 = _mm_rsqrt_ps(a); // +Inf for +0, 0 for +Inf
+
+    // Newton-Raphson: x1 = x0 * (1.5 - 0.5 * a * x0^2)
+    __m128 x1 = _mm_mul_ps(
         x0, _mm_sub_ps(THREE_HALFS, _mm_mul_ps(HALF, _mm_mul_ps(_mm_mul_ps(x0, x0), a))));
+
+    // For +0 and +Inf inputs, x0 is correct but NR produces NaN due to Inf*0
+    // Blend: use x0 where a == +0 or a == +Inf, else use x1
+    __m128i a_si = _mm_castps_si128(a);
+    __m128i zero_mask = _mm_cmpeq_epi32(a_si, _mm_setzero_si128());
+    __m128i inf_mask = _mm_cmpeq_epi32(a_si, _mm_set1_epi32(0x7F800000));
+    __m128 special_mask = _mm_castsi128_ps(_mm_or_si128(zero_mask, inf_mask));
+    // SSE2-compatible blend: (x0 & mask) | (x1 & ~mask)
+    return _mm_or_ps(_mm_and_ps(special_mask, x0), _mm_andnot_ps(special_mask, x1));
 }
 
 /*
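
The and/andnot/or combination at the end stands in for _mm_blendv_ps, which is SSE4.1 and unavailable at this baseline; the newly added <emmintrin.h> include supplies the SSE2 integer compares. A hypothetical scalar reference for unit-testing any of the five kernels (the name and the memcpy bit-extraction are illustrative; not part of the commit):

#include <math.h>
#include <stdint.h>
#include <string.h>

static inline float rsqrt_nr_ref(float a)
{
    uint32_t bits;
    memcpy(&bits, &a, sizeof bits);   // same view of a as casting ps to epi32
    if (bits == 0x00000000u)
        return INFINITY;              // +0 -> +Inf, bypassing NR
    if (bits == 0x7F800000u)
        return 0.0f;                  // +Inf -> 0, bypassing NR
    float x0 = 1.0f / sqrtf(a);       // stand-in for the hardware estimate
    return x0 * (1.5f - 0.5f * a * x0 * x0); // one NR step
}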
