@@ -2,7 +2,7 @@ use core::arch::aarch64::*;
22use core:: iter:: zip;
33use core:: mem;
44
5- use crate :: danger :: { DenseLane , SimdRegister } ;
5+ use super :: core_simd_api :: { DenseLane , SimdRegister , Hypot } ;
66use crate :: math:: { AutoMath , Math } ;
77
88const BITS_8_CAPACITY : usize = 16 ;
@@ -146,11 +146,48 @@ impl SimdRegister<f32> for Neon {
146146 vst1q_f32 ( mem, reg)
147147 }
148148}
149-
149+ const EXPONENT_MASK_F32 : u32 = 2139095040 ;
150+ const MANTISSA_MASK_F32 : u32 = 8388607 ;
150151impl Hypot < f32 > for Neon {
151152 #[ inline( always) ]
152- unsafe fn hypot ( l1 : Self :: Register , l2 : Self :: Register ) -> Self :: Register {
153- todo ! ( )
153+ unsafe fn hypot ( x : Self :: Register , y : Self :: Register ) -> Self :: Register {
154+ // Convert inputs to absolute values
155+ let ( x, y) = ( vabsq_f32 ( x) , vabsq_f32 ( y) ) ;
156+
157+ // Find the max and min of the two inputs
158+ let ( hi, lo) = ( vmaxq_f32 ( x, y) , vminq_f32 ( x, y) ) ;
159+ let exponent_mask = vdupq_n_u32 ( EXPONENT_MASK_F32 ) ;
160+ let mantissa_mask = vdupq_n_u32 ( MANTISSA_MASK_F32 ) ;
161+
162+ // round the hi values down to the nearest power of 2
163+ let hi2p =
164+ vreinterpretq_f32_u32 ( vandq_u32 ( vreinterpretq_u32_f32 ( hi) , exponent_mask) ) ;
165+ // we scale the values inside the root by the reciprocal of hi2p. since it's a power of 2,
166+ // we can double it and xor it with the exponent mask
167+ let scale = vreinterpretq_f32_u32 ( veorq_u32 (
168+ vreinterpretq_u32_f32 ( vaddq_f32 ( hi2p, hi2p) ) ,
169+ exponent_mask,
170+ ) ) ;
171+ // create a mask that matches the normal hi values
172+ let mask = vcgtq_f32 ( hi, vdupq_n_f32 ( f32:: MIN_POSITIVE ) ) ;
173+ // replace the subnormal values of hi2p with the minimum positive normal value
174+ let hi2p = vbslq_f32 ( mask, hi2p, vdupq_n_f32 ( f32:: MIN_POSITIVE ) ) ;
175+ // replace the subnormal values of scale with the reciprocal of the minimum positive normal value
176+ let scale = vbslq_f32 ( mask, scale, vdupq_n_f32 ( 1.0 / f32:: MIN_POSITIVE ) ) ;
177+ // create a mask that matches the subnormal hi values
178+ let mask = vcltq_f32 ( hi, vdupq_n_f32 ( f32:: MIN_POSITIVE ) ) ;
179+ // since hi2p was preserved the exponent bits of hi, the exponent of hi/hi2p is 1
180+ let hi_scaled = vreinterpretq_f32_u32 ( vorrq_u32 (
181+ vandq_u32 ( vreinterpretq_u32_f32 ( hi) , mantissa_mask) ,
182+ vreinterpretq_u32_f32 ( vdupq_n_f32 ( 1.0 ) ) ,
183+ ) ) ;
184+ // for the subnormal elements of hi, we need to subtract 1 from the scaled hi values
185+ let hi_scaled =
186+ vbslq_f32 ( mask, vsubq_f32 ( hi_scaled, vdupq_n_f32 ( 1.0 ) ) , hi_scaled) ;
187+ // finally, do the thing
188+ let hi_scaled = vmulq_f32 ( hi_scaled, hi_scaled) ;
189+ let lo_scaled = vmulq_f32 ( lo, scale) ;
190+ vmulq_f32 ( hi2p, vsqrtq_f32 ( vfmaq_f32 ( lo_scaled, lo_scaled, hi_scaled) ) )
154191 }
155192}
156193
@@ -286,6 +323,49 @@ impl SimdRegister<f64> for Neon {
286323 }
287324}
288325
326+ impl Hypot < f64 > for Neon {
327+ #[ inline( always) ]
328+ unsafe fn hypot ( x : Self :: Register , y : Self :: Register ) -> Self :: Register {
329+ // Convert inputs to absolute values
330+ let ( x, y) = ( vabsq_f64 ( x) , vabsq_f64 ( y) ) ;
331+
332+ // Find the max and min of the two inputs
333+ let ( hi, lo) = ( vmaxq_f64 ( x, y) , vminq_f64 ( x, y) ) ;
334+ let exponent_mask = vdupq_n_u64 ( f64:: INFINITY . to_bits ( ) ) ;
335+ let mantissa_mask = vdupq_n_u64 ( ( f64:: MIN_POSITIVE - mem:: transmute :: < u64 , f64 > ( 1 ) ) . to_bits ( ) ) ;
336+
337+ // round the hi values down to the nearest power of 2
338+ let hi2p =
339+ vreinterpretq_f64_u64 ( vandq_u64 ( vreinterpretq_u64_f64 ( hi) , exponent_mask) ) ;
340+ // we scale the values inside the root by the reciprocal of hi2p. since it's a power of 2,
341+ // we can double it and xor it with the exponent mask
342+ let scale = vreinterpretq_f64_u64 ( veorq_u64 (
343+ vreinterpretq_u64_f64 ( vaddq_f64 ( hi2p, hi2p) ) ,
344+ exponent_mask,
345+ ) ) ;
346+ // create a mask that matches the normal hi values
347+ let mask = vcgtq_f64 ( hi, vdupq_n_f64 ( f64:: MIN_POSITIVE ) ) ;
348+ // replace the subnormal values of hi2p with the minimum positive normal value
349+ let hi2p = vbslq_f64 ( mask, hi2p, vdupq_n_f64 ( f64:: MIN_POSITIVE ) ) ;
350+ // replace the subnormal values of scale with the reciprocal of the minimum positive normal value
351+ let scale = vbslq_f64 ( mask, scale, vdupq_n_f64 ( 1.0 / f64:: MIN_POSITIVE ) ) ;
352+ // create a mask that matches the subnormal hi values
353+ let mask = vcltq_f64 ( hi, vdupq_n_f64 ( f64:: MIN_POSITIVE ) ) ;
354+ // since hi2p was preserved the exponent bits of hi, the exponent of hi/hi2p is 1
355+ let hi_scaled = vreinterpretq_f64_u64 ( vorrq_u64 (
356+ vandq_u64 ( vreinterpretq_u64_f64 ( hi) , mantissa_mask) ,
357+ vreinterpretq_u64_f64 ( vdupq_n_f64 ( 1.0 ) ) ,
358+ ) ) ;
359+ // for the subnormal elements of hi, we need to subtract 1 from the scaled hi values
360+ let hi_scaled =
361+ vbslq_f64 ( mask, vsubq_f64 ( hi_scaled, vdupq_n_f64 ( 1.0 ) ) , hi_scaled) ;
362+ // finally, do the thing
363+ let hi_scaled = vmulq_f64 ( hi_scaled, hi_scaled) ;
364+ let lo_scaled = vmulq_f64 ( lo, scale) ;
365+ vmulq_f64 ( hi2p, vsqrtq_f64 ( vfmaq_f64 ( lo_scaled, lo_scaled, hi_scaled) ) )
366+ }
367+ }
368+
289369impl SimdRegister < i8 > for Neon {
290370 type Register = int8x16_t ;
291371
0 commit comments