|
1 | 1 | /* -*- c++ -*- */ |
2 | 2 | /* |
3 | 3 | * Copyright 2015 Free Software Foundation, Inc. |
4 | | - * Copyright 2023 Magnus Lundmark <[email protected]> |
| 4 | + * Copyright 2023-2026 Magnus Lundmark <[email protected]> |
5 | 5 | * |
6 | 6 | * This file is part of VOLK |
7 | 7 | * |
|
17 | 17 | #define INCLUDE_VOLK_VOLK_AVX_INTRINSICS_H_ |
18 | 18 | #include <immintrin.h> |
19 | 19 |
|
/*
 * Newton-Raphson refined reciprocal square root: 1/sqrt(a)
 *
 * Starts from the ~12-bit hardware estimate x0 = rsqrt(a) and applies one
 * Newton-Raphson iteration, which doubles the precision to ~24 bits:
 *   x1 = x0 * (1.5 - 0.5 * a * x0^2)
 *
 * Edge cases: whenever the estimate x0 is 0 or +/-Inf the refinement step
 * would evaluate Inf * 0 (NaN) or flip the sign of the result, so the raw
 * estimate is returned unchanged for those lanes:
 *   +0 -> +Inf, -0 -> -Inf, +Inf -> 0
 *   denormal inputs that the estimate flushes to zero -> Inf of the same sign
 * Lanes where the estimate is NaN (negative or NaN input) stay NaN.
 */
static inline __m256 _mm256_rsqrt_nr_ps(const __m256 a)
{
    const __m256 HALF = _mm256_set1_ps(0.5f);
    const __m256 THREE_HALFS = _mm256_set1_ps(1.5f);
    const __m256 ZERO = _mm256_setzero_ps();
    const __m256 SIGN_BIT = _mm256_set1_ps(-0.0f);
    const __m256 POS_INF = _mm256_castsi256_ps(_mm256_set1_epi32(0x7F800000));

    // Hardware estimate: +/-Inf for +/-0 (and flushed denormals), 0 for +Inf
    const __m256 x0 = _mm256_rsqrt_ps(a);

    // Newton-Raphson: x1 = x0 * (1.5 - 0.5 * a * x0^2)
    const __m256 x0_sq = _mm256_mul_ps(x0, x0);
    const __m256 half_a_x0_sq = _mm256_mul_ps(HALF, _mm256_mul_ps(x0_sq, a));
    const __m256 x1 = _mm256_mul_ps(x0, _mm256_sub_ps(THREE_HALFS, half_a_x0_sq));

    // Select on the estimate rather than on the input bit pattern: this also
    // covers -0 and denormal inputs, for which the refinement is invalid.
    // Ordered compares are false for NaN, so NaN lanes take x1 (also NaN).
    const __m256 x0_abs = _mm256_andnot_ps(SIGN_BIT, x0);
    const __m256 x0_is_zero = _mm256_cmp_ps(x0, ZERO, _CMP_EQ_OQ);
    const __m256 x0_is_inf = _mm256_cmp_ps(x0_abs, POS_INF, _CMP_EQ_OQ);
    const __m256 keep_estimate = _mm256_or_ps(x0_is_zero, x0_is_inf);

    return _mm256_blendv_ps(x1, x0, keep_estimate);
}
| 56 | + |
20 | 57 | /* |
21 | 58 | * Approximate arctan(x) via polynomial expansion |
22 | 59 | * on the interval [-1, 1] |
|
0 commit comments