Skip to content

Commit 5044fef

Browse files
authored
Fix: Turin kernels for spdot (#252)
1 parent d3ee357 commit 5044fef

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

include/simsimd/sparse.h

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -678,7 +678,7 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( //
678678

679679
// The baseline implementation for very small arrays (2 registers or less) can be quite simple:
680680
if (a_length < 64 && b_length < 64) {
681-
simsimd_intersect_u16_serial(a, b, a_length, b_length, results);
681+
simsimd_spdot_weights_u16_serial(a, b, a_weights, b_weights, a_length, b_length, results);
682682
return;
683683
}
684684

@@ -751,9 +751,9 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( //
751751
a += a_step, a_weights += a_step;
752752
b += b_step, b_weights += b_step;
753753
}
754-
755-
simsimd_intersect_u16_serial(a, b, a_end - a, b_end - b, results);
756-
*results += intersection_size;
754+
simsimd_spdot_weights_u16_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, results);
755+
results[0] += intersection_size;
756+
results[1] += _mm512_reduce_add_ps(_mm512_insertf32x8(_mm512_setzero_ps(), product_vec.ymmps, 0));
757757
}
758758

759759
SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
@@ -764,7 +764,7 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
764764

765765
// The baseline implementation for very small arrays (2 registers or less) can be quite simple:
766766
if (a_length < 64 && b_length < 64) {
767-
simsimd_intersect_u16_serial(a, b, a_length, b_length, results);
767+
simsimd_spdot_counts_u16_serial(a, b, a_weights, b_weights, a_length, b_length, results);
768768
return;
769769
}
770770

@@ -837,8 +837,9 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( //
837837
b += b_step, b_weights += b_step;
838838
}
839839

840-
simsimd_intersect_u16_serial(a, b, a_end - a, b_end - b, results);
841-
*results += intersection_size;
840+
simsimd_spdot_counts_u16_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, results);
841+
results[0] += intersection_size;
842+
results[1] += _mm512_reduce_add_epi32(_mm512_inserti64x4(_mm512_setzero_si512(), product_vec.ymm, 0));
842843
}
843844

844845
#pragma clang attribute pop

0 commit comments

Comments
 (0)