Skip to content

Commit 8791a32

Browse files
committed
Prefer shuffle for float and double
1 parent b0ac87c commit 8791a32

File tree

1 file changed

+43
-37
lines changed

1 file changed

+43
-37
lines changed

flatbush.h

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -357,20 +357,27 @@ inline double computeDistanceSquared(const Point<ArrayType>& iPoint,
357357
}
358358

359359
#if defined(FLATBUSH_USE_SIMD)
360-
static const auto kZeroPd = _mm_setzero_pd();
361-
static const auto kZeroPs = _mm_setzero_ps();
360+
static constexpr auto kShuffleUnpackLo = _MM_SHUFFLE(1, 0, 1, 0);
361+
static constexpr auto kShuffleUnpackHi = _MM_SHUFFLE(3, 2, 3, 2);
362+
static constexpr auto kShuffleBroadcast0 = _MM_SHUFFLE(0, 0, 0, 0);
363+
static constexpr auto kShuffleBroadcast1 = _MM_SHUFFLE(1, 1, 1, 1);
364+
static constexpr auto kShuffleMergeMinMax = _MM_SHUFFLE(3, 2, 1, 0);
365+
static constexpr auto kShuffleExchange01 = _MM_SHUFFLE2(0, 1);
362366
static const auto kOffset8 = _mm_set1_epi8(std::numeric_limits<int8_t>::min());
363367
static const auto kOffset16 = _mm_set1_epi16(std::numeric_limits<int16_t>::min());
364368
static const auto kOffset32 = _mm_set1_epi32(std::numeric_limits<int32_t>::min());
365369
static const auto kShuffleMin =
366370
_mm_setr_epi8(0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
367371
static const auto kShuffleMax =
368372
_mm_setr_epi8(2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
373+
static const auto kZeroPd = _mm_setzero_pd();
374+
static const auto kZeroPs = _mm_setzero_ps();
375+
376+
static const auto kMaskAllOnes = _mm_set1_epi32(0xFFFF);
369377
static const auto kMaskInterleave1 = _mm_set1_epi32(0x00FF00FF);
370378
static const auto kMaskInterleave2 = _mm_set1_epi32(0x0F0F0F0F);
371379
static const auto kMaskInterleave3 = _mm_set1_epi32(0x33333333);
372380
static const auto kMaskInterleave4 = _mm_set1_epi32(0x55555555);
373-
static const auto kMaskAllOnes = _mm_set1_epi32(0xFFFF);
374381

375382
#if FLATBUSH_USE_SIMD >= FLATBUSH_USE_AVX512
376383
static const auto kShuffleMinX512 = _mm512_setr_epi64(0, 4, 8, 12, 0, 0, 0, 0);
@@ -459,8 +466,8 @@ template <>
459466
inline bool boxesIntersect<float>(const Box<float>& iQuery, const Box<float>& iBox) noexcept {
460467
const auto wQuery = _mm_loadu_ps(&iQuery.mMinX);
461468
const auto wBox = _mm_loadu_ps(&iBox.mMinX);
462-
const auto wMin = _mm_unpacklo_ps(wQuery, wBox);
463-
const auto wMax = _mm_unpackhi_ps(wBox, wQuery);
469+
const auto wMin = _mm_shuffle_ps(wQuery, wBox, kShuffleUnpackLo);
470+
const auto wMax = _mm_shuffle_ps(wBox, wQuery, kShuffleUnpackHi);
464471
#if FLATBUSH_USE_SIMD >= FLATBUSH_USE_AVX512
465472
return _mm_cmp_ps_mask(wMax, wMin, _CMP_LT_OQ) == 0;
466473
#elif FLATBUSH_USE_SIMD >= FLATBUSH_USE_AVX
@@ -504,10 +511,10 @@ inline bool boxesIntersect<int8_t>(const Box<int8_t>& iQuery, const Box<int8_t>&
504511
const auto wMax = _mm_unpacklo_epi16(_mm_shuffle_epi8(wBox, kShuffleMax),
505512
_mm_shuffle_epi8(wQuery, kShuffleMax));
506513
#else
507-
const auto wMin = _mm_unpacklo_epi8(_mm_shufflelo_epi16(wQuery, _MM_SHUFFLE(0, 0, 0, 0)),
508-
_mm_shufflelo_epi16(wBox, _MM_SHUFFLE(0, 0, 0, 0)));
509-
const auto wMax = _mm_unpacklo_epi8(_mm_shufflelo_epi16(wBox, _MM_SHUFFLE(1, 1, 1, 1)),
510-
_mm_shufflelo_epi16(wQuery, _MM_SHUFFLE(1, 1, 1, 1)));
514+
const auto wMin = _mm_unpacklo_epi8(_mm_shufflelo_epi16(wQuery, kShuffleBroadcast0),
515+
_mm_shufflelo_epi16(wBox, kShuffleBroadcast0));
516+
const auto wMax = _mm_unpacklo_epi8(_mm_shufflelo_epi16(wBox, kShuffleBroadcast1),
517+
_mm_shufflelo_epi16(wQuery, kShuffleBroadcast1));
511518
#endif
512519
const auto wCmp = _mm_cmplt_epi8(wMax, wMin);
513520
#if FLATBUSH_USE_SIMD >= FLATBUSH_USE_SSE4
@@ -527,10 +534,10 @@ inline bool boxesIntersect<uint8_t>(const Box<uint8_t>& iQuery, const Box<uint8_
527534
const auto wMax = _mm_unpacklo_epi16(_mm_shuffle_epi8(wBox, kShuffleMax),
528535
_mm_shuffle_epi8(wQuery, kShuffleMax));
529536
#else
530-
const auto wMin = _mm_unpacklo_epi8(_mm_shufflelo_epi16(wQuery, _MM_SHUFFLE(0, 0, 0, 0)),
531-
_mm_shufflelo_epi16(wBox, _MM_SHUFFLE(0, 0, 0, 0)));
532-
const auto wMax = _mm_unpacklo_epi8(_mm_shufflelo_epi16(wBox, _MM_SHUFFLE(1, 1, 1, 1)),
533-
_mm_shufflelo_epi16(wQuery, _MM_SHUFFLE(1, 1, 1, 1)));
537+
const auto wMin = _mm_unpacklo_epi8(_mm_shufflelo_epi16(wQuery, kShuffleBroadcast0),
538+
_mm_shufflelo_epi16(wBox, kShuffleBroadcast0));
539+
const auto wMax = _mm_unpacklo_epi8(_mm_shufflelo_epi16(wBox, kShuffleBroadcast1),
540+
_mm_shufflelo_epi16(wQuery, kShuffleBroadcast1));
534541
#endif
535542
const auto wCmp = _mm_cmplt_epi8(_mm_add_epi8(wMax, kOffset8), _mm_add_epi8(wMin, kOffset8));
536543
#if FLATBUSH_USE_SIMD >= FLATBUSH_USE_SSE4
@@ -544,10 +551,10 @@ template <>
544551
inline bool boxesIntersect<int16_t>(const Box<int16_t>& iQuery, const Box<int16_t>& iBox) noexcept {
545552
const auto wQuery = _mm_loadu_si64(&iQuery.mMinX);
546553
const auto wBox = _mm_loadu_si64(&iBox.mMinX);
547-
const auto wMin = _mm_unpacklo_epi16(_mm_shuffle_epi32(wQuery, _MM_SHUFFLE(0, 0, 0, 0)),
548-
_mm_shuffle_epi32(wBox, _MM_SHUFFLE(0, 0, 0, 0)));
549-
const auto wMax = _mm_unpacklo_epi16(_mm_shuffle_epi32(wBox, _MM_SHUFFLE(1, 1, 1, 1)),
550-
_mm_shuffle_epi32(wQuery, _MM_SHUFFLE(1, 1, 1, 1)));
554+
const auto wMin = _mm_unpacklo_epi16(_mm_shuffle_epi32(wQuery, kShuffleBroadcast0),
555+
_mm_shuffle_epi32(wBox, kShuffleBroadcast0));
556+
const auto wMax = _mm_unpacklo_epi16(_mm_shuffle_epi32(wBox, kShuffleBroadcast1),
557+
_mm_shuffle_epi32(wQuery, kShuffleBroadcast1));
551558
const auto wCmp = _mm_cmplt_epi16(wMax, wMin);
552559
#if FLATBUSH_USE_SIMD >= FLATBUSH_USE_SSE4
553560
return _mm_testz_si128(wCmp, wCmp);
@@ -561,10 +568,10 @@ inline bool boxesIntersect<uint16_t>(const Box<uint16_t>& iQuery,
561568
const Box<uint16_t>& iBox) noexcept {
562569
const auto wQuery = _mm_loadu_si64(&iQuery.mMinX);
563570
const auto wBox = _mm_loadu_si64(&iBox.mMinX);
564-
const auto wMin = _mm_unpacklo_epi16(_mm_shuffle_epi32(wQuery, _MM_SHUFFLE(0, 0, 0, 0)),
565-
_mm_shuffle_epi32(wBox, _MM_SHUFFLE(0, 0, 0, 0)));
566-
const auto wMax = _mm_unpacklo_epi16(_mm_shuffle_epi32(wBox, _MM_SHUFFLE(1, 1, 1, 1)),
567-
_mm_shuffle_epi32(wQuery, _MM_SHUFFLE(1, 1, 1, 1)));
571+
const auto wMin = _mm_unpacklo_epi16(_mm_shuffle_epi32(wQuery, kShuffleBroadcast0),
572+
_mm_shuffle_epi32(wBox, kShuffleBroadcast0));
573+
const auto wMax = _mm_unpacklo_epi16(_mm_shuffle_epi32(wBox, kShuffleBroadcast1),
574+
_mm_shuffle_epi32(wQuery, kShuffleBroadcast1));
568575
const auto wCmp = _mm_cmplt_epi16(_mm_add_epi16(wMax, kOffset16), _mm_add_epi16(wMin, kOffset16));
569576
#if FLATBUSH_USE_SIMD >= FLATBUSH_USE_SSE4
570577
return _mm_testz_si128(wCmp, wCmp);
@@ -611,7 +618,7 @@ inline void updateBounds<float>(Box<float>& ioSrc, const Box<float>& iBox) noexc
611618
#if FLATBUSH_USE_SIMD >= FLATBUSH_USE_SSE4
612619
_mm_storeu_ps(&ioSrc.mMinX, _mm_blend_ps(wMins, wMaxs, 0xC));
613620
#else
614-
_mm_storeu_ps(&ioSrc.mMinX, _mm_shuffle_ps(wMins, wMaxs, _MM_SHUFFLE(3, 2, 1, 0)));
621+
_mm_storeu_ps(&ioSrc.mMinX, _mm_shuffle_ps(wMins, wMaxs, kShuffleMergeMinMax));
615622
#endif
616623
}
617624

@@ -740,7 +747,7 @@ inline double computeDistanceSquared<double>(const Point<double>& iPoint,
740747
const auto wResult = _mm_hadd_pd(wDistSq, wDistSq);
741748
#else
742749
const auto wDistSq = _mm_mul_pd(wDist, wDist);
743-
const auto wResult = _mm_add_pd(wDistSq, _mm_unpackhi_pd(wDistSq, wDistSq));
750+
const auto wResult = _mm_add_pd(wDistSq, _mm_shuffle_pd(wDistSq, wDistSq, kShuffleExchange01));
744751
#endif
745752
return _mm_cvtsd_f64(wResult);
746753
}
@@ -749,13 +756,13 @@ template <>
749756
inline double computeDistanceSquared<float>(const Point<float>& iPoint,
750757
const Box<float>& iBox) noexcept {
751758
const auto wPoint = _mm_castpd_ps(_mm_load_sd(bit_cast<const double*>(&iPoint.mX)));
752-
const auto wPointHl = _mm_movelh_ps(wPoint, wPoint);
759+
const auto wPoint2 = _mm_shuffle_ps(wPoint, wPoint, kShuffleUnpackLo);
753760
const auto wBox = _mm_loadu_ps(&iBox.mMinX);
754-
const auto wBoxMin = _mm_movelh_ps(wBox, wBox);
755-
const auto wBoxMax = _mm_movehl_ps(wBox, wBox);
761+
const auto wBoxMin = _mm_shuffle_ps(wBox, wBox, kShuffleUnpackLo);
762+
const auto wBoxMax = _mm_shuffle_ps(wBox, wBox, kShuffleUnpackHi);
756763
// Compute axis distances - using max to clamp to zero
757764
const auto wDist =
758-
_mm_max_ps(kZeroPs, _mm_max_ps(_mm_sub_ps(wBoxMin, wPointHl), _mm_sub_ps(wPointHl, wBoxMax)));
765+
_mm_max_ps(kZeroPs, _mm_max_ps(_mm_sub_ps(wBoxMin, wPoint2), _mm_sub_ps(wPoint2, wBoxMax)));
759766
// Square and sum
760767
#if FLATBUSH_USE_SIMD >= FLATBUSH_USE_SSE4
761768
const auto wResult = _mm_dp_ps(wDist, wDist, 0x31);
@@ -764,8 +771,7 @@ inline double computeDistanceSquared<float>(const Point<float>& iPoint,
764771
const auto wResult = _mm_hadd_ps(wDistSq, wDistSq);
765772
#else
766773
const auto wDistSq = _mm_mul_ps(wDist, wDist);
767-
const auto wShuf = _mm_shuffle_ps(wDistSq, wDistSq, _MM_SHUFFLE(1, 1, 1, 1));
768-
const auto wResult = _mm_add_ps(wDistSq, wShuf);
774+
const auto wResult = _mm_add_ps(wDistSq, _mm_shuffle_ps(wDistSq, wDistSq, kShuffleBroadcast1));
769775
#endif
770776
return static_cast<double>(_mm_cvtss_f32(wResult));
771777
}
@@ -1070,10 +1076,10 @@ inline std::vector<uint32_t> computeHilbertValues<double>(size_t iNumItems,
10701076
const auto wBox1 = _mm256_loadu_pd(&iBoxes[wIdx + 1].mMinX);
10711077
const auto wBox2 = _mm256_loadu_pd(&iBoxes[wIdx + 2].mMinX);
10721078
const auto wBox3 = _mm256_loadu_pd(&iBoxes[wIdx + 3].mMinX);
1073-
const auto wBoxes01Lo = _mm256_unpacklo_pd(wBox0, wBox1);
1074-
const auto wBoxes01Hi = _mm256_unpackhi_pd(wBox0, wBox1);
1075-
const auto wBoxes23Lo = _mm256_unpacklo_pd(wBox2, wBox3);
1076-
const auto wBoxes23Hi = _mm256_unpackhi_pd(wBox2, wBox3);
1079+
const auto wBoxes01Lo = _mm256_shuffle_pd(wBox0, wBox1, 0x0);
1080+
const auto wBoxes01Hi = _mm256_shuffle_pd(wBox0, wBox1, 0xF);
1081+
const auto wBoxes23Lo = _mm256_shuffle_pd(wBox2, wBox3, 0x0);
1082+
const auto wBoxes23Hi = _mm256_shuffle_pd(wBox2, wBox3, 0xF);
10771083
const auto wMinX = _mm256_permute2f128_pd(wBoxes01Lo, wBoxes23Lo, 0x20);
10781084
const auto wMinY = _mm256_permute2f128_pd(wBoxes01Hi, wBoxes23Hi, 0x20);
10791085
const auto wMaxX = _mm256_permute2f128_pd(wBoxes01Lo, wBoxes23Lo, 0x31);
@@ -1101,10 +1107,10 @@ inline std::vector<uint32_t> computeHilbertValues<double>(size_t iNumItems,
11011107
const auto wBox1Lo = _mm_loadu_pd(&iBoxes[wIdx + 1].mMinX);
11021108
const auto wBox1Hi = _mm_loadu_pd(&iBoxes[wIdx + 1].mMaxX);
11031109

1104-
const auto wMinX = _mm_unpacklo_pd(wBox0Lo, wBox1Lo);
1105-
const auto wMinY = _mm_unpackhi_pd(wBox0Lo, wBox1Lo);
1106-
const auto wMaxX = _mm_unpacklo_pd(wBox0Hi, wBox1Hi);
1107-
const auto wMaxY = _mm_unpackhi_pd(wBox0Hi, wBox1Hi);
1110+
const auto wMinX = _mm_shuffle_pd(wBox0Lo, wBox1Lo, 0x0);
1111+
const auto wMinY = _mm_shuffle_pd(wBox0Lo, wBox1Lo, 0x3);
1112+
const auto wMaxX = _mm_shuffle_pd(wBox0Hi, wBox1Hi, 0x0);
1113+
const auto wMaxY = _mm_shuffle_pd(wBox0Hi, wBox1Hi, 0x3);
11081114

11091115
const auto wSumX = _mm_add_pd(wMinX, wMaxX);
11101116
const auto wSumY = _mm_add_pd(wMinY, wMaxY);

0 commit comments

Comments
 (0)