Skip to content

Commit f921334

Browse files
authored
Merge pull request #808 from Ka-zam/saturated_sum_kernels
new saturated kernels
2 parents cda245e + 382ac93 commit f921334

File tree

5 files changed

+1408
-0
lines changed

5 files changed

+1408
-0
lines changed
Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
/* -*- c++ -*- */
2+
/*
3+
* Copyright 2025 Magnus Lundmark <[email protected]>
4+
*
5+
* This file is part of VOLK
6+
*
7+
* SPDX-License-Identifier: LGPL-3.0-or-later
8+
*/
9+
10+
/*!
11+
* \page volk_16i_x2_add_saturated_16i
12+
*
13+
* \b Overview
14+
*
15+
* Adds two int16_t vectors element-wise with saturation. Results are clamped
16+
* to the range [-32768, 32767] to prevent overflow wraparound.
17+
*
18+
* <b>Dispatcher Prototype</b>
19+
* \code
20+
* void volk_16i_x2_add_saturated_16i(int16_t* outVector, const int16_t* inVectorA, const
21+
* int16_t* inVectorB, unsigned int num_points) \endcode
22+
*
23+
* \b Inputs
24+
* \li inVectorA: First input vector.
25+
* \li inVectorB: Second input vector.
26+
* \li num_points: Vector length.
27+
*
28+
* \b Outputs
29+
* \li outVector: Saturated sum output.
30+
*
31+
* \b Example
32+
* \code
33+
* unsigned int N = 2; // only the two elements initialized below are processed
34+
* unsigned int align = volk_get_alignment();
35+
* int16_t* a = (int16_t*)volk_malloc(N * sizeof(int16_t), align);
36+
* int16_t* b = (int16_t*)volk_malloc(N * sizeof(int16_t), align);
37+
* int16_t* result = (int16_t*)volk_malloc(N * sizeof(int16_t), align);
38+
*
39+
* // Values that will cause saturation
40+
* a[0] = 30000; b[0] = 10000; // 40000 -> saturates to 32767
41+
* a[1] = -30000; b[1] = -10000; // -40000 -> saturates to -32768
42+
*
43+
* volk_16i_x2_add_saturated_16i(result, a, b, N);
44+
* // result[0] == 32767, result[1] == -32768
45+
*
46+
* volk_free(a);
47+
* volk_free(b);
48+
* volk_free(result);
49+
* \endcode
50+
*/
51+
52+
#ifndef INCLUDED_volk_16i_x2_add_saturated_16i_u_H
53+
#define INCLUDED_volk_16i_x2_add_saturated_16i_u_H
54+
55+
#include <inttypes.h>
56+
57+
#ifdef LV_HAVE_GENERIC
58+
59+
/*!
 * \brief Portable reference implementation of the saturated int16_t sum.
 *
 * Each pair of elements is added in 32-bit arithmetic, where the sum of two
 * int16_t values always fits, and is then clamped to [-32768, 32767]. This
 * avoids depending on implementation-defined behavior (narrowing an
 * out-of-range value to int16_t, right-shifting a negative value) and matches
 * the scalar tail handling of the SIMD kernels in this file.
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process; 0 leaves outVector untouched.
 */
static inline void volk_16i_x2_add_saturated_16i_generic(int16_t* outVector,
                                                         const int16_t* inVectorA,
                                                         const int16_t* inVectorB,
                                                         unsigned int num_points)
{
    for (unsigned int i = 0; i < num_points; i++) {
        /* Widen first: the exact sum lies in [-65536, 65534]. */
        int32_t sum = (int32_t)inVectorA[i] + (int32_t)inVectorB[i];
        if (sum > 32767) {
            sum = 32767;
        } else if (sum < -32768) {
            sum = -32768;
        }
        outVector[i] = (int16_t)sum;
    }
}
75+
76+
#endif /* LV_HAVE_GENERIC */
77+
78+
79+
#ifdef LV_HAVE_SSE2
80+
#include <emmintrin.h>
81+
82+
/*!
 * \brief SSE2 saturated int16_t sum using unaligned loads and stores.
 *
 * Processes eight elements per iteration with _mm_adds_epi16; the remaining
 * (num_points % 8) elements are added in 32-bit arithmetic and clamped to
 * [-32768, 32767].
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_u_sse2(int16_t* outVector,
                                                        const int16_t* inVectorA,
                                                        const int16_t* inVectorB,
                                                        unsigned int num_points)
{
    const unsigned int lanes = 8; /* int16_t elements per 128-bit register */
    const unsigned int vectorEnd = num_points - (num_points % lanes);
    unsigned int idx;

    for (idx = 0; idx < vectorEnd; idx += lanes) {
        const __m128i lhs = _mm_loadu_si128((const __m128i*)&inVectorA[idx]);
        const __m128i rhs = _mm_loadu_si128((const __m128i*)&inVectorB[idx]);
        _mm_storeu_si128((__m128i*)&outVector[idx], _mm_adds_epi16(lhs, rhs));
    }

    /* Scalar tail: widen, add, clamp. */
    for (; idx < num_points; ++idx) {
        const int32_t wide = (int32_t)inVectorA[idx] + (int32_t)inVectorB[idx];
        outVector[idx] =
            (int16_t)(wide > 32767 ? 32767 : (wide < -32768 ? -32768 : wide));
    }
}
106+
107+
#endif /* LV_HAVE_SSE2 */
108+
109+
110+
#ifdef LV_HAVE_AVX2
111+
#include <immintrin.h>
112+
113+
/*!
 * \brief AVX2 saturated int16_t sum using unaligned loads and stores.
 *
 * Sixteen elements are summed per iteration with _mm256_adds_epi16. The
 * trailing (num_points % 16) elements are summed in 32-bit arithmetic and
 * clamped to [-32768, 32767].
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_u_avx2(int16_t* outVector,
                                                        const int16_t* inVectorA,
                                                        const int16_t* inVectorB,
                                                        unsigned int num_points)
{
    const int16_t* srcA = inVectorA;
    const int16_t* srcB = inVectorB;
    int16_t* dst = outVector;

    /* Full 256-bit blocks. */
    for (unsigned int blocks = num_points / 16; blocks != 0; --blocks) {
        const __m256i va = _mm256_loadu_si256((const __m256i*)srcA);
        const __m256i vb = _mm256_loadu_si256((const __m256i*)srcB);
        _mm256_storeu_si256((__m256i*)dst, _mm256_adds_epi16(va, vb));
        srcA += 16;
        srcB += 16;
        dst += 16;
    }

    /* Leftover elements, one at a time. */
    for (unsigned int rest = num_points % 16; rest != 0; --rest) {
        int32_t wide = (int32_t)(*srcA++) + (int32_t)(*srcB++);
        if (wide > 32767) {
            wide = 32767;
        }
        if (wide < -32768) {
            wide = -32768;
        }
        *dst++ = (int16_t)wide;
    }
}
137+
138+
#endif /* LV_HAVE_AVX2 */
139+
140+
141+
#ifdef LV_HAVE_AVX512BW
142+
#include <immintrin.h>
143+
144+
/*!
 * \brief AVX-512BW saturated int16_t sum using unaligned loads and stores.
 *
 * Thirty-two elements are summed per iteration with _mm512_adds_epi16. The
 * trailing (num_points % 32) elements are summed in 32-bit arithmetic and
 * clamped to [-32768, 32767].
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_u_avx512bw(int16_t* outVector,
                                                            const int16_t* inVectorA,
                                                            const int16_t* inVectorB,
                                                            unsigned int num_points)
{
    /* Largest multiple of 32 that does not exceed num_points. */
    const unsigned int simdEnd = num_points & ~31u;
    unsigned int pos = 0;

    while (pos < simdEnd) {
        const __m512i va = _mm512_loadu_si512((const __m512i*)(inVectorA + pos));
        const __m512i vb = _mm512_loadu_si512((const __m512i*)(inVectorB + pos));
        _mm512_storeu_si512((__m512i*)(outVector + pos), _mm512_adds_epi16(va, vb));
        pos += 32;
    }

    /* Scalar tail: widen, add, clamp. */
    while (pos < num_points) {
        const int32_t wide = (int32_t)inVectorA[pos] + (int32_t)inVectorB[pos];
        const int32_t capped = (wide > 32767) ? 32767 : wide;
        outVector[pos] = (int16_t)((capped < -32768) ? -32768 : capped);
        ++pos;
    }
}
168+
169+
#endif /* LV_HAVE_AVX512BW */
170+
171+
172+
#endif /* INCLUDED_volk_16i_x2_add_saturated_16i_u_H */
173+
174+
175+
#ifndef INCLUDED_volk_16i_x2_add_saturated_16i_a_H
176+
#define INCLUDED_volk_16i_x2_add_saturated_16i_a_H
177+
178+
#include <inttypes.h>
179+
180+
#ifdef LV_HAVE_SSE2
181+
#include <emmintrin.h>
182+
183+
/*!
 * \brief SSE2 saturated int16_t sum using aligned loads and stores.
 *
 * Uses _mm_load_si128 / _mm_store_si128, so all three buffers must be
 * 16-byte aligned. Eight elements are summed per iteration with
 * _mm_adds_epi16; the trailing (num_points % 8) elements are summed in
 * 32-bit arithmetic and clamped to [-32768, 32767].
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_a_sse2(int16_t* outVector,
                                                        const int16_t* inVectorA,
                                                        const int16_t* inVectorB,
                                                        unsigned int num_points)
{
    const __m128i* vecA = (const __m128i*)inVectorA;
    const __m128i* vecB = (const __m128i*)inVectorB;
    __m128i* vecOut = (__m128i*)outVector;
    const unsigned int blockCount = num_points / 8;

    for (unsigned int blk = 0; blk < blockCount; ++blk) {
        const __m128i lhs = _mm_load_si128(vecA + blk);
        const __m128i rhs = _mm_load_si128(vecB + blk);
        _mm_store_si128(vecOut + blk, _mm_adds_epi16(lhs, rhs));
    }

    /* Elements past the last full 128-bit block. */
    for (unsigned int idx = blockCount * 8; idx < num_points; ++idx) {
        int32_t wide = (int32_t)inVectorA[idx] + (int32_t)inVectorB[idx];
        if (wide < -32768) {
            wide = -32768;
        } else if (wide > 32767) {
            wide = 32767;
        }
        outVector[idx] = (int16_t)wide;
    }
}
207+
208+
#endif /* LV_HAVE_SSE2 */
209+
210+
211+
#ifdef LV_HAVE_AVX2
212+
#include <immintrin.h>
213+
214+
/*!
 * \brief AVX2 saturated int16_t sum using aligned loads and stores.
 *
 * Uses _mm256_load_si256 / _mm256_store_si256, so all three buffers must be
 * 32-byte aligned. Sixteen elements are summed per iteration with
 * _mm256_adds_epi16; the trailing (num_points % 16) elements are summed in
 * 32-bit arithmetic and clamped to [-32768, 32767].
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_a_avx2(int16_t* outVector,
                                                        const int16_t* inVectorA,
                                                        const int16_t* inVectorB,
                                                        unsigned int num_points)
{
    const unsigned int step = 16; /* int16_t elements per 256-bit register */
    const unsigned int alignedCount = (num_points / step) * step;
    unsigned int offset;

    for (offset = 0; offset < alignedCount; offset += step) {
        const __m256i va = _mm256_load_si256((const __m256i*)(inVectorA + offset));
        const __m256i vb = _mm256_load_si256((const __m256i*)(inVectorB + offset));
        const __m256i vsum = _mm256_adds_epi16(va, vb);
        _mm256_store_si256((__m256i*)(outVector + offset), vsum);
    }

    /* Scalar tail: widen, add, clamp. */
    for (; offset < num_points; ++offset) {
        const int32_t wide = (int32_t)inVectorA[offset] + (int32_t)inVectorB[offset];
        outVector[offset] =
            (int16_t)(wide > 32767 ? 32767 : (wide < -32768 ? -32768 : wide));
    }
}
238+
239+
#endif /* LV_HAVE_AVX2 */
240+
241+
242+
#ifdef LV_HAVE_AVX512BW
243+
#include <immintrin.h>
244+
245+
/*!
 * \brief AVX-512BW saturated int16_t sum using aligned loads and stores.
 *
 * Uses _mm512_load_si512 / _mm512_store_si512, so all three buffers must be
 * 64-byte aligned. Thirty-two elements are summed per iteration with
 * _mm512_adds_epi16; the trailing (num_points % 32) elements are summed in
 * 32-bit arithmetic and clamped to [-32768, 32767].
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_a_avx512bw(int16_t* outVector,
                                                            const int16_t* inVectorA,
                                                            const int16_t* inVectorB,
                                                            unsigned int num_points)
{
    const int16_t* srcA = inVectorA;
    const int16_t* srcB = inVectorB;
    int16_t* dst = outVector;

    /* Full 512-bit blocks. */
    for (unsigned int blocks = num_points / 32; blocks != 0; --blocks) {
        const __m512i va = _mm512_load_si512((const __m512i*)srcA);
        const __m512i vb = _mm512_load_si512((const __m512i*)srcB);
        _mm512_store_si512((__m512i*)dst, _mm512_adds_epi16(va, vb));
        srcA += 32;
        srcB += 32;
        dst += 32;
    }

    /* Leftover elements, one at a time. */
    for (unsigned int rest = num_points % 32; rest != 0; --rest) {
        int32_t wide = (int32_t)(*srcA++) + (int32_t)(*srcB++);
        if (wide > 32767) {
            wide = 32767;
        } else if (wide < -32768) {
            wide = -32768;
        }
        *dst++ = (int16_t)wide;
    }
}
269+
270+
#endif /* LV_HAVE_AVX512BW */
271+
272+
273+
#ifdef LV_HAVE_NEON
274+
#include <arm_neon.h>
275+
276+
/*!
 * \brief NEON saturated int16_t sum, eight elements per iteration.
 *
 * Full groups of eight are summed with vqaddq_s16. The trailing
 * (num_points % 8) elements are summed in 32-bit arithmetic and clamped to
 * [-32768, 32767].
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_neon(int16_t* outVector,
                                                      const int16_t* inVectorA,
                                                      const int16_t* inVectorB,
                                                      unsigned int num_points)
{
    const unsigned int vectorEnd = (num_points / 8) * 8;
    unsigned int idx = 0;

    while (idx < vectorEnd) {
        const int16x8_t lhs = vld1q_s16(&inVectorA[idx]);
        const int16x8_t rhs = vld1q_s16(&inVectorB[idx]);
        const int16x8_t total = vqaddq_s16(lhs, rhs);
        vst1q_s16(&outVector[idx], total);
        idx += 8;
    }

    /* Scalar tail: widen, add, clamp. */
    while (idx < num_points) {
        const int32_t wide = (int32_t)inVectorA[idx] + (int32_t)inVectorB[idx];
        outVector[idx] =
            (int16_t)(wide > 32767 ? 32767 : (wide < -32768 ? -32768 : wide));
        ++idx;
    }
}
299+
300+
#endif /* LV_HAVE_NEON */
301+
302+
303+
#ifdef LV_HAVE_NEONV8
304+
#include <arm_neon.h>
305+
#include <volk/volk_common.h>
306+
307+
/*!
 * \brief NEON (ARMv8) saturated int16_t sum, sixteen elements per iteration.
 *
 * Each iteration invokes __VOLK_PREFETCH on the addresses 32 elements ahead
 * of the current read positions, loads two 8-element registers from each
 * input, and stores both vqaddq_s16 results. The trailing (num_points % 16)
 * elements are summed in 32-bit arithmetic and clamped to [-32768, 32767].
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_neonv8(int16_t* outVector,
                                                        const int16_t* inVectorA,
                                                        const int16_t* inVectorB,
                                                        unsigned int num_points)
{
    const int16_t* srcA = inVectorA;
    const int16_t* srcB = inVectorB;
    int16_t* dst = outVector;
    unsigned int remaining = num_points;

    /* Two 128-bit registers per input per pass. */
    for (; remaining >= 16; remaining -= 16) {
        __VOLK_PREFETCH(srcA + 32);
        __VOLK_PREFETCH(srcB + 32);
        const int16x8_t lowA = vld1q_s16(srcA);
        const int16x8_t lowB = vld1q_s16(srcB);
        const int16x8_t highA = vld1q_s16(srcA + 8);
        const int16x8_t highB = vld1q_s16(srcB + 8);
        vst1q_s16(dst, vqaddq_s16(lowA, lowB));
        vst1q_s16(dst + 8, vqaddq_s16(highA, highB));
        srcA += 16;
        srcB += 16;
        dst += 16;
    }

    /* Fewer than sixteen elements left: widen, add, clamp. */
    for (; remaining > 0; --remaining) {
        const int32_t wide = (int32_t)(*srcA++) + (int32_t)(*srcB++);
        *dst++ = (int16_t)(wide > 32767 ? 32767 : (wide < -32768 ? -32768 : wide));
    }
}
338+
339+
#endif /* LV_HAVE_NEONV8 */
340+
341+
342+
#ifdef LV_HAVE_RVV
343+
#include <riscv_vector.h>
344+
345+
/*!
 * \brief RISC-V Vector saturated int16_t sum.
 *
 * Strip-mines the input: each pass asks __riscv_vsetvl_e16m8 how many
 * elements to handle, loads that many from both inputs, and stores the
 * __riscv_vsadd result. No separate scalar tail is needed because the final
 * pass simply uses a shorter vector length.
 *
 * \param outVector  Destination for num_points saturated sums.
 * \param inVectorA  First input vector.
 * \param inVectorB  Second input vector.
 * \param num_points Number of elements to process.
 */
static inline void volk_16i_x2_add_saturated_16i_rvv(int16_t* outVector,
                                                     const int16_t* inVectorA,
                                                     const int16_t* inVectorB,
                                                     unsigned int num_points)
{
    size_t remaining = num_points;

    while (remaining > 0) {
        const size_t vl = __riscv_vsetvl_e16m8(remaining);
        const vint16m8_t va = __riscv_vle16_v_i16m8(inVectorA, vl);
        const vint16m8_t vb = __riscv_vle16_v_i16m8(inVectorB, vl);
        __riscv_vse16(outVector, __riscv_vsadd(va, vb, vl), vl);

        /* Advance all cursors by the number of elements just processed. */
        remaining -= vl;
        inVectorA += vl;
        inVectorB += vl;
        outVector += vl;
    }
}
358+
359+
#endif /* LV_HAVE_RVV */
360+
361+
362+
#endif /* INCLUDED_volk_16i_x2_add_saturated_16i_a_H */

0 commit comments

Comments
 (0)