Skip to content

Commit 11db444

Browse files
committed
use sqrt intrinsic for fastmath, implemented Hypot for Neon
1 parent a343307 commit 11db444

File tree

3 files changed

+108
-8
lines changed

3 files changed

+108
-8
lines changed

cfavml/src/danger/impl_neon.rs

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use core::arch::aarch64::*;
22
use core::iter::zip;
33
use core::mem;
44

5-
use crate::danger::{DenseLane, SimdRegister};
5+
use super::core_simd_api::{DenseLane, SimdRegister,Hypot};
66
use crate::math::{AutoMath, Math};
77

88
const BITS_8_CAPACITY: usize = 16;
@@ -146,11 +146,48 @@ impl SimdRegister<f32> for Neon {
146146
vst1q_f32(mem, reg)
147147
}
148148
}
149-
149+
const EXPONENT_MASK_F32: u32 = 2139095040;
150+
const MANTISSA_MASK_F32: u32 = 8388607;
150151
impl Hypot<f32> for Neon {
151152
#[inline(always)]
152-
unsafe fn hypot(l1: Self::Register, l2: Self::Register) -> Self::Register {
153-
todo!()
153+
unsafe fn hypot(x: Self::Register, y: Self::Register) -> Self::Register {
154+
// Convert inputs to absolute values
155+
let (x, y) = (vabsq_f32(x), vabsq_f32(y));
156+
157+
// Find the max and min of the two inputs
158+
let (hi, lo) = (vmaxq_f32(x, y), vminq_f32(x, y));
159+
let exponent_mask = vdupq_n_u32(EXPONENT_MASK_F32);
160+
let mantissa_mask = vdupq_n_u32(MANTISSA_MASK_F32);
161+
162+
// round the hi values down to the nearest power of 2
163+
let hi2p =
164+
vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(hi), exponent_mask));
165+
// we scale the values inside the root by the reciprocal of hi2p. since it's a power of 2,
166+
// we can double it and xor it with the exponent mask
167+
let scale = vreinterpretq_f32_u32(veorq_u32(
168+
vreinterpretq_u32_f32(vaddq_f32(hi2p, hi2p)),
169+
exponent_mask,
170+
));
171+
// create a mask that matches the normal hi values
172+
let mask = vcgtq_f32(hi, vdupq_n_f32(f32::MIN_POSITIVE));
173+
// replace the subnormal values of hi2p with the minimum positive normal value
174+
let hi2p = vbslq_f32(mask, hi2p, vdupq_n_f32(f32::MIN_POSITIVE));
175+
// replace the subnormal values of scale with the reciprocal of the minimum positive normal value
176+
let scale = vbslq_f32(mask, scale, vdupq_n_f32(1.0 / f32::MIN_POSITIVE));
177+
// create a mask that matches the subnormal hi values
178+
let mask = vcltq_f32(hi, vdupq_n_f32(f32::MIN_POSITIVE));
179+
// since hi2p was preserved the exponent bits of hi, the exponent of hi/hi2p is 1
180+
let hi_scaled = vreinterpretq_f32_u32(vorrq_u32(
181+
vandq_u32(vreinterpretq_u32_f32(hi), mantissa_mask),
182+
vreinterpretq_u32_f32(vdupq_n_f32(1.0)),
183+
));
184+
// for the subnormal elements of hi, we need to subtract 1 from the scaled hi values
185+
let hi_scaled =
186+
vbslq_f32(mask, vsubq_f32(hi_scaled, vdupq_n_f32(1.0)), hi_scaled);
187+
// finally, do the thing
188+
let hi_scaled = vmulq_f32(hi_scaled, hi_scaled);
189+
let lo_scaled = vmulq_f32(lo, scale);
190+
vmulq_f32(hi2p, vsqrtq_f32(vfmaq_f32(lo_scaled, lo_scaled, hi_scaled)))
154191
}
155192
}
156193

@@ -286,6 +323,49 @@ impl SimdRegister<f64> for Neon {
286323
}
287324
}
288325

326+
impl Hypot<f64> for Neon {
327+
#[inline(always)]
328+
unsafe fn hypot(x: Self::Register, y: Self::Register) -> Self::Register {
329+
// Convert inputs to absolute values
330+
let (x, y) = (vabsq_f64(x), vabsq_f64(y));
331+
332+
// Find the max and min of the two inputs
333+
let (hi, lo) = (vmaxq_f64(x, y), vminq_f64(x, y));
334+
let exponent_mask = vdupq_n_u64(f64::INFINITY.to_bits());
335+
let mantissa_mask = vdupq_n_u64((f64::MIN_POSITIVE - mem::transmute::<u64,f64>(1)).to_bits());
336+
337+
// round the hi values down to the nearest power of 2
338+
let hi2p =
339+
vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(hi), exponent_mask));
340+
// we scale the values inside the root by the reciprocal of hi2p. since it's a power of 2,
341+
// we can double it and xor it with the exponent mask
342+
let scale = vreinterpretq_f64_u64(veorq_u64(
343+
vreinterpretq_u64_f64(vaddq_f64(hi2p, hi2p)),
344+
exponent_mask,
345+
));
346+
// create a mask that matches the normal hi values
347+
let mask = vcgtq_f64(hi, vdupq_n_f64(f64::MIN_POSITIVE));
348+
// replace the subnormal values of hi2p with the minimum positive normal value
349+
let hi2p = vbslq_f64(mask, hi2p, vdupq_n_f64(f64::MIN_POSITIVE));
350+
// replace the subnormal values of scale with the reciprocal of the minimum positive normal value
351+
let scale = vbslq_f64(mask, scale, vdupq_n_f64(1.0 / f64::MIN_POSITIVE));
352+
// create a mask that matches the subnormal hi values
353+
let mask = vcltq_f64(hi, vdupq_n_f64(f64::MIN_POSITIVE));
354+
// since hi2p was preserved the exponent bits of hi, the exponent of hi/hi2p is 1
355+
let hi_scaled = vreinterpretq_f64_u64(vorrq_u64(
356+
vandq_u64(vreinterpretq_u64_f64(hi), mantissa_mask),
357+
vreinterpretq_u64_f64(vdupq_n_f64(1.0)),
358+
));
359+
// for the subnormal elements of hi, we need to subtract 1 from the scaled hi values
360+
let hi_scaled =
361+
vbslq_f64(mask, vsubq_f64(hi_scaled, vdupq_n_f64(1.0)), hi_scaled);
362+
// finally, do the thing
363+
let hi_scaled = vmulq_f64(hi_scaled, hi_scaled);
364+
let lo_scaled = vmulq_f64(lo, scale);
365+
vmulq_f64(hi2p, vsqrtq_f64(vfmaq_f64(lo_scaled, lo_scaled, hi_scaled)))
366+
}
367+
}
368+
289369
impl SimdRegister<i8> for Neon {
290370
type Register = int8x16_t;
291371

cfavml/src/danger/impl_test.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ unsafe fn test_sample<T, R>(
326326
Standard: Distribution<T>,
327327
{
328328
{
329-
let (_std_result, std_sum) = get_std_results(&sample1, &sample2);
329+
let (std_result, std_sum) = get_std_results(&sample1, &sample2);
330330
let l1 = R::load(sample1.as_ptr());
331331
let l2 = R::load(sample2.as_ptr());
332332
let res = R::hypot(l1, l2);
@@ -335,9 +335,12 @@ unsafe fn test_sample<T, R>(
335335
AutoMath::is_close(std_sum, res_sum),
336336
"Hypot and sum test failed on single task"
337337
);
338+
let mut res_vec = vec![T::zero(); R::elements_per_lane()];
339+
R::write(res_vec.as_mut_ptr(), res);
340+
test_diff_ulps(std_result, res_vec);
338341
}
339342
{
340-
let (_std_result, std_sum) = get_std_results(&large_sample_l1, &large_sample_l2);
343+
let (std_result, std_sum) = get_std_results(&large_sample_l1, &large_sample_l2);
341344
let l1 = R::load_dense(large_sample_l1.as_ptr());
342345
let l2 = R::load_dense(large_sample_l2.as_ptr());
343346
let res = R::hypot_dense(l1, l2);
@@ -348,6 +351,9 @@ unsafe fn test_sample<T, R>(
348351
AutoMath::is_close(std_sum, res_sum),
349352
"Hypot and sum test failed on dense task"
350353
);
354+
let mut res_vec = vec![T::zero(); R::elements_per_dense()];
355+
R::write_dense(res_vec.as_mut_ptr(), res);
356+
test_diff_ulps(std_result, res_vec);
351357
}
352358
}
353359

@@ -364,3 +370,17 @@ where
364370
let sum = std_result.iter().fold(AutoMath::zero(), |a, b| a + *b);
365371
(std_result, sum)
366372
}
373+
374+
fn test_diff_ulps<T>(a: Vec<T>, b: Vec<T>)
375+
where
376+
T: Float + FloatConst,
377+
{
378+
a.iter().zip(b.iter()).for_each(|(a, b)| {
379+
let (a_mant, a_exp, a_sign) = a.integer_decode();
380+
let (b_mant, b_exp, b_sign) = b.integer_decode();
381+
assert!(a_sign == b_sign);
382+
assert!(a_exp == b_exp);
383+
let dist = a_mant as i64 - b_mant as i64;
384+
assert!(dist.abs() < 2, "Greater than 1 ulp difference: {dist}");
385+
});
386+
}

cfavml/src/math/fast_math.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl Math<f32> for FastMath {
2828

2929
#[inline(always)]
3030
fn sqrt(a: f32) -> f32 {
31-
StdMath::sqrt(a)
31+
core::intrinsics::sqrtf32(a)
3232
}
3333

3434
#[inline(always)]
@@ -139,7 +139,7 @@ impl Math<f64> for FastMath {
139139

140140
#[inline(always)]
141141
fn sqrt(a: f64) -> f64 {
142-
StdMath::sqrt(a)
142+
core::intrinsics::sqrtf64(a)
143143
}
144144

145145
#[inline(always)]

0 commit comments

Comments
 (0)