
Commit ef8f6b6

Add qs8/qu8 vlrelu kernels, configs and tests.
1 parent: eb4486b

15 files changed: +336 -6 lines

bench/qs8-vlrelu.cc

Lines changed: 15 additions & 0 deletions
@@ -240,6 +240,21 @@ static void qs8_vlrelu(
     ->UseRealTime();
 #endif  // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD

+#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
+  BENCHMARK_CAPTURE(qs8_vlrelu, rvv_u1v,
+                    xnn_qs8_vlrelu_ukernel__rvv_u1v,
+                    xnn_init_qs8_lrelu_scalar_params,
+                    benchmark::utils::CheckRVV)
+    ->Apply(benchmark::utils::BinaryElementwiseParameters<int8_t, int8_t>)
+    ->UseRealTime();
+  BENCHMARK_CAPTURE(qs8_vlrelu, rvv_u2v,
+                    xnn_qs8_vlrelu_ukernel__rvv_u2v,
+                    xnn_init_qs8_lrelu_scalar_params,
+                    benchmark::utils::CheckRVV)
+    ->Apply(benchmark::utils::BinaryElementwiseParameters<int8_t, int8_t>)
+    ->UseRealTime();
+#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
+
 BENCHMARK_CAPTURE(qs8_vlrelu, scalar_andxor_u1,
                   xnn_qs8_vlrelu_ukernel__scalar_andxor_u1,
                   xnn_init_qs8_lrelu_scalar_params)

bench/qu8-vlrelu.cc

Lines changed: 15 additions & 0 deletions
@@ -240,6 +240,21 @@ static void qu8_vlrelu(
     ->UseRealTime();
 #endif  // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD

+#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
+  BENCHMARK_CAPTURE(qu8_vlrelu, rvv_u1v,
+                    xnn_qu8_vlrelu_ukernel__rvv_u1v,
+                    xnn_init_qu8_lrelu_scalar_params,
+                    benchmark::utils::CheckRVV)
+    ->Apply(benchmark::utils::BinaryElementwiseParameters<uint8_t, uint8_t>)
+    ->UseRealTime();
+  BENCHMARK_CAPTURE(qu8_vlrelu, rvv_u2v,
+                    xnn_qu8_vlrelu_ukernel__rvv_u2v,
+                    xnn_init_qu8_lrelu_scalar_params,
+                    benchmark::utils::CheckRVV)
+    ->Apply(benchmark::utils::BinaryElementwiseParameters<uint8_t, uint8_t>)
+    ->UseRealTime();
+#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
+
 BENCHMARK_CAPTURE(qu8_vlrelu, scalar_andxor_u1,
                   xnn_qu8_vlrelu_ukernel__scalar_andxor_u1,
                   xnn_init_qu8_lrelu_scalar_params)

cmake/gen/rvv_microkernels.cmake

Lines changed: 4 additions & 0 deletions
@@ -44,8 +44,10 @@ SET(PROD_RVV_MICROKERNEL_SRCS
   src/f32-vlrelu/gen/f32-vlrelu-rvv-u4v.c
   src/f32-vrelu/gen/f32-vrelu-rvv-u4v.c
   src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c
+  src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c
   src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c
   src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c
+  src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u2v.c
   src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u2v.c
   src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u2v.c
   src/x32-transposec/gen/x32-transposec-4x4-rvv.c
@@ -155,8 +157,10 @@ SET(NON_PROD_RVV_MICROKERNEL_SRCS
   src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c
   src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c
   src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c
+  src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u1v.c
   src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u1v.c
   src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u1v.c
+  src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u1v.c
   src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u1v.c
   src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u1v.c
   src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u2.c

gen/rvv_microkernels.bzl

Lines changed: 4 additions & 0 deletions
@@ -40,8 +40,10 @@ PROD_RVV_MICROKERNEL_SRCS = [
     "src/f32-vlrelu/gen/f32-vlrelu-rvv-u4v.c",
     "src/f32-vrelu/gen/f32-vrelu-rvv-u4v.c",
     "src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c",
+    "src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c",
     "src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c",
     "src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c",
+    "src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u2v.c",
     "src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u2v.c",
     "src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u2v.c",
     "src/x32-transposec/gen/x32-transposec-4x4-rvv.c",
@@ -152,8 +154,10 @@ NON_PROD_RVV_MICROKERNEL_SRCS = [
     "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c",
     "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c",
     "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c",
+    "src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u1v.c",
     "src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u1v.c",
     "src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u1v.c",
+    "src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u1v.c",
     "src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u1v.c",
     "src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u1v.c",
     "src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u2.c",

scripts/generate-qs8-vlrelu.sh

Lines changed: 8 additions & 0 deletions
@@ -90,6 +90,14 @@ tools/xngen src/qs8-vlrelu/armsimd32.c.in -D BATCH_TILE=8 -D DATATYPE=QS8 -o src
 tools/xngen src/qs8-vlrelu/armsimd32.c.in -D BATCH_TILE=4 -D DATATYPE=QU8 -o src/qu8-vlrelu/gen/qu8-vlrelu-armsimd32-u4.c &
 tools/xngen src/qs8-vlrelu/armsimd32.c.in -D BATCH_TILE=8 -D DATATYPE=QU8 -o src/qu8-vlrelu/gen/qu8-vlrelu-armsimd32-u8.c &

+################################ RISC-V Vector ################################
+tools/xngen src/qs8-vlrelu/rvv.c.in -D LMUL=1 -D DATATYPE=QS8 -o src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u1v.c &
+tools/xngen src/qs8-vlrelu/rvv.c.in -D LMUL=2 -D DATATYPE=QS8 -o src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c &
+
+tools/xngen src/qs8-vlrelu/rvv.c.in -D LMUL=1 -D DATATYPE=QU8 -o src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u1v.c &
+tools/xngen src/qs8-vlrelu/rvv.c.in -D LMUL=2 -D DATATYPE=QU8 -o src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u2v.c &
+
+
 #################################### Scalar ###################################
 tools/xngen src/qs8-vlrelu/scalar-select.c.in -D BATCH_TILE=1 -D DATATYPE=QS8 -o src/qs8-vlrelu/gen/qs8-vlrelu-scalar-select-u1.c &
 tools/xngen src/qs8-vlrelu/scalar-select.c.in -D BATCH_TILE=2 -D DATATYPE=QS8 -o src/qs8-vlrelu/gen/qs8-vlrelu-scalar-select-u2.c &

src/configs/unary-elementwise-config.c

Lines changed: 8 additions & 6 deletions
@@ -1907,10 +1907,11 @@ static void init_qs8_lrelu_config(void) {
     qs8_lrelu_config.init.qs8_lrelu = xnn_init_qs8_lrelu_scalar_params;
     qs8_lrelu_config.element_tile = 4;
   }
-  #elif XNN_ARCH_RISCV
-    qs8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_vlrelu_ukernel__scalar_andxor_u4;
+  #elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
+    const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
+    qs8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_vlrelu_ukernel__rvv_u2v;
     qs8_lrelu_config.init.qs8_lrelu = xnn_init_qs8_lrelu_scalar_params;
-    qs8_lrelu_config.element_tile = 4;
+    qs8_lrelu_config.element_tile = hardware_config->vlenb / sizeof(int8_t) * 2;
   #else
     qs8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_vlrelu_ukernel__scalar_andxor_u4;
     qs8_lrelu_config.init.qs8_lrelu = xnn_init_qs8_lrelu_scalar_params;
@@ -2149,10 +2150,11 @@ static void init_qu8_lrelu_config(void) {
     qu8_lrelu_config.init.qu8_lrelu = xnn_init_qu8_lrelu_scalar_params;
     qu8_lrelu_config.element_tile = 4;
   }
-  #elif XNN_ARCH_RISCV
-    qu8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_vlrelu_ukernel__scalar_andxor_u4;
+  #elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
+    const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
+    qu8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_vlrelu_ukernel__rvv_u2v;
     qu8_lrelu_config.init.qu8_lrelu = xnn_init_qu8_lrelu_scalar_params;
-    qu8_lrelu_config.element_tile = 4;
+    qu8_lrelu_config.element_tile = hardware_config->vlenb / sizeof(int8_t) * 2;
   #else
     qu8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_vlrelu_ukernel__scalar_andxor_u4;
     qu8_lrelu_config.init.qu8_lrelu = xnn_init_qu8_lrelu_scalar_params;
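
For context on the new element_tile expression: vlenb is the vector register length in bytes, so a u2v (LMUL=2) 8-bit kernel processes two full vector registers per iteration. A minimal sketch of the arithmetic, assuming a hypothetical core with VLEN = 256 bits (not part of the commit):

#include <stdint.h>
#include <stdio.h>

int main(void) {
  // Hypothetical hardware: VLEN = 256 bits -> vlenb = 32 bytes per vector register.
  const size_t vlenb = 256 / 8;
  // Matches the config above: one int8 element per byte, times LMUL = 2 registers.
  const size_t element_tile = vlenb / sizeof(int8_t) * 2;
  printf("element_tile = %zu elements per iteration\n", element_tile);  // prints 64
  return 0;
}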

src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u1v.c

Lines changed: 51 additions & 0 deletions

// Auto-generated file. Do not edit!
//   Template: src/qs8-vlrelu/rvv.c.in
//   Generator: tools/xngen
//
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vlrelu.h"


void xnn_qs8_vlrelu_ukernel__rvv_u1v(
    size_t batch,
    const int8_t* input,
    int8_t* output,
    const struct xnn_qs8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(int8_t) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const int32_t input_zero_point = params->scalar.input_zero_point;
  const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
  const int32_t multiplier_base = params->scalar.positive_multiplier;
  const int32_t bias = (params->scalar.output_zero_point << 8) + 128;
  int32_t n = __riscv_vsetvl_e8m1(batch);
  vint32m4_t bias_i32v = __riscv_vmv_v_x_i32m4(bias, n);

  do {
    n = __riscv_vsetvl_e8m1(batch); batch -= n;

    vint8m1_t in_i8v = __riscv_vle8_v_i8m1(input, n); input += n;
    vint16m2_t acc_i16v = __riscv_vwsub_vx_i16m2(in_i8v, input_zero_point, n);

    vint32m4_t acc_i32v = __riscv_vwcvt_x_x_v_i32m4(acc_i16v, n);
    vint32m4_t sra_i32v = __riscv_vsra_vx_i32m4(acc_i32v, 31, n);
    vint32m4_t and_i32v = __riscv_vand_vx_i32m4(sra_i32v, multiplier_diff, n);
    vint32m4_t mult_i32v = __riscv_vxor_vx_i32m4(and_i32v, multiplier_base, n);
    acc_i32v = __riscv_vmacc_vv_i32m4(bias_i32v, acc_i32v, mult_i32v, n);

    vint16m2_t out_i16v = __riscv_vnclip_wx_i16m2(acc_i32v, 8, __RISCV_VXRM_RDN, n);
    vint8m1_t out_i8v = __riscv_vnclip_wx_i8m1(out_i16v, 0, __RISCV_VXRM_RNU, n);
    __riscv_vse8_v_i8m1(output, out_i8v, n); output += n;
  } while (batch != 0);
}
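
For readers unfamiliar with the and/xor multiplier-selection trick in the kernel above, here is a rough scalar restatement of the same arithmetic for one element. It is a sketch for illustration only, not part of the commit; the parameter names follow the fields read by the kernel, and it assumes arithmetic right shift of negative values:

#include <stdint.h>

// Scalar restatement of one element of the QS8 vector kernel above.
static inline int8_t qs8_lrelu_ref(int8_t x,
                                   int32_t input_zero_point,
                                   int32_t positive_multiplier,
                                   int32_t negative_multiplier,
                                   int32_t output_zero_point) {
  const int32_t multiplier_diff = negative_multiplier ^ positive_multiplier;
  const int32_t multiplier_base = positive_multiplier;
  const int32_t bias = (output_zero_point << 8) + 128;  // +128 pre-rounds the final >> 8

  int32_t acc = (int32_t) x - input_zero_point;
  // acc >> 31 is all-ones for negative inputs and all-zeros otherwise, so the
  // and/xor pair selects negative_multiplier or positive_multiplier without a branch.
  const int32_t mult = ((acc >> 31) & multiplier_diff) ^ multiplier_base;
  acc = bias + acc * mult;

  // The two vnclip steps amount to a floor shift by 8 followed by saturation to int8.
  acc >>= 8;
  if (acc < INT8_MIN) acc = INT8_MIN;
  if (acc > INT8_MAX) acc = INT8_MAX;
  return (int8_t) acc;
}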

src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c

Lines changed: 51 additions & 0 deletions

// Auto-generated file. Do not edit!
//   Template: src/qs8-vlrelu/rvv.c.in
//   Generator: tools/xngen
//
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vlrelu.h"


void xnn_qs8_vlrelu_ukernel__rvv_u2v(
    size_t batch,
    const int8_t* input,
    int8_t* output,
    const struct xnn_qs8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(int8_t) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const int32_t input_zero_point = params->scalar.input_zero_point;
  const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
  const int32_t multiplier_base = params->scalar.positive_multiplier;
  const int32_t bias = (params->scalar.output_zero_point << 8) + 128;
  int32_t n = __riscv_vsetvl_e8m2(batch);
  vint32m8_t bias_i32v = __riscv_vmv_v_x_i32m8(bias, n);

  do {
    n = __riscv_vsetvl_e8m2(batch); batch -= n;

    vint8m2_t in_i8v = __riscv_vle8_v_i8m2(input, n); input += n;
    vint16m4_t acc_i16v = __riscv_vwsub_vx_i16m4(in_i8v, input_zero_point, n);

    vint32m8_t acc_i32v = __riscv_vwcvt_x_x_v_i32m8(acc_i16v, n);
    vint32m8_t sra_i32v = __riscv_vsra_vx_i32m8(acc_i32v, 31, n);
    vint32m8_t and_i32v = __riscv_vand_vx_i32m8(sra_i32v, multiplier_diff, n);
    vint32m8_t mult_i32v = __riscv_vxor_vx_i32m8(and_i32v, multiplier_base, n);
    acc_i32v = __riscv_vmacc_vv_i32m8(bias_i32v, acc_i32v, mult_i32v, n);

    vint16m4_t out_i16v = __riscv_vnclip_wx_i16m4(acc_i32v, 8, __RISCV_VXRM_RDN, n);
    vint8m2_t out_i8v = __riscv_vnclip_wx_i8m2(out_i16v, 0, __RISCV_VXRM_RNU, n);
    __riscv_vse8_v_i8m2(output, out_i8v, n); output += n;
  } while (batch != 0);
}

src/qs8-vlrelu/qs8-vlrelu.h

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,11 @@ XNN_UKERNEL_WITH_PARAMS(xnn_arch_arm_neon, xnn_qs8_vlrelu_ukernel__neon_u16, 16,
 XNN_UKERNEL_WITH_PARAMS(xnn_arch_arm_neon, xnn_qs8_vlrelu_ukernel__neon_u32, 32, false, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
 #endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64

+#if XNN_ENABLE_RISCV_VECTOR && (XNN_ARCH_RISCV)
+XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vlrelu_ukernel__rvv_u1v, 1, true, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
+XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vlrelu_ukernel__rvv_u2v, 2, true, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
+#endif  // XNN_ENABLE_RISCV_VECTOR && (XNN_ARCH_RISCV)
+
 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
 XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vlrelu_ukernel__sse2_u16, 16, false, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
 XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vlrelu_ukernel__sse2_u32, 32, false, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
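
The header above is an X-macro table: consumers define XNN_UKERNEL_WITH_PARAMS and then expand the rows to enumerate every registered kernel. A minimal sketch of that pattern; the kernel_entry struct and the expansion are illustrative assumptions, not XNNPACK's actual harness, and the two new RVV rows are repeated inline instead of including the header:

#include <stdbool.h>
#include <stdio.h>

// Hypothetical record type for one table row; only three of the macro
// arguments are kept, the rest are discarded by the preprocessor.
struct kernel_entry {
  const char* name;
  int tile;
  bool vector_tile;  // true: tile is in units of the hardware vector length
};

#define XNN_UKERNEL_WITH_PARAMS(arch_flags, ukernel, tile, vector_tile, datatype, params_type, init_fn) \
  { #ukernel, tile, vector_tile },

static const struct kernel_entry qs8_vlrelu_table[] = {
  XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vlrelu_ukernel__rvv_u1v, 1, true, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
  XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vlrelu_ukernel__rvv_u2v, 2, true, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
};

int main(void) {
  for (size_t i = 0; i < sizeof(qs8_vlrelu_table) / sizeof(qs8_vlrelu_table[0]); i++) {
    printf("%s: tile=%d vector=%d\n", qs8_vlrelu_table[i].name,
           qs8_vlrelu_table[i].tile, (int) qs8_vlrelu_table[i].vector_tile);
  }
  return 0;
}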

src/qs8-vlrelu/rvv.c.in

Lines changed: 62 additions & 0 deletions
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL in [1, 2, 4, 8]
$assert DATATYPE in ["QS8", "QU8"]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vlrelu.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]

void xnn_${DATATYPE.lower()}_vlrelu_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const ${XINT8_T}* input,
    ${XINT8_T}* output,
    const struct xnn_${DATATYPE.lower()}_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(${XINT8_T}) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const int32_t input_zero_point = params->scalar.input_zero_point;
  const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
  const int32_t multiplier_base = params->scalar.positive_multiplier;
  const int32_t bias = (params->scalar.output_zero_point << 8) + 128;
  int32_t n = __riscv_vsetvl_e8m${LMUL}(batch);
  vint32m${LMUL*4}_t bias_i32v = __riscv_vmv_v_x_i32m${LMUL*4}(bias, n);

  do {
    n = __riscv_vsetvl_e8m${LMUL}(batch); batch -= n;

    $if DATATYPE == "QS8":
      vint8m${LMUL}_t in_i8v = __riscv_vle8_v_i8m${LMUL}(input, n); input += n;
      vint16m${LMUL*2}_t acc_i16v = __riscv_vwsub_vx_i16m${LMUL*2}(in_i8v, input_zero_point, n);
    $else:
      vuint8m${LMUL}_t in_u8v = __riscv_vle8_v_u8m${LMUL}(input, n); input += n;
      vuint16m${LMUL*2}_t acc_u16v = __riscv_vwsubu_vx_u16m${LMUL*2}(in_u8v, input_zero_point, n);
      vint16m${LMUL*2}_t acc_i16v = __riscv_vreinterpret_v_u16m${LMUL*2}_i16m${LMUL*2}(acc_u16v);

    vint32m${LMUL*4}_t acc_i32v = __riscv_vwcvt_x_x_v_i32m${LMUL*4}(acc_i16v, n);
    vint32m${LMUL*4}_t sra_i32v = __riscv_vsra_vx_i32m${LMUL*4}(acc_i32v, 31, n);
    vint32m${LMUL*4}_t and_i32v = __riscv_vand_vx_i32m${LMUL*4}(sra_i32v, multiplier_diff, n);
    vint32m${LMUL*4}_t mult_i32v = __riscv_vxor_vx_i32m${LMUL*4}(and_i32v, multiplier_base, n);
    acc_i32v = __riscv_vmacc_vv_i32m${LMUL*4}(bias_i32v, acc_i32v, mult_i32v, n);

    $if DATATYPE == "QS8":
      vint16m${LMUL*2}_t out_i16v = __riscv_vnclip_wx_i16m${LMUL*2}(acc_i32v, 8, __RISCV_VXRM_RDN, n);
      vint8m${LMUL}_t out_i8v = __riscv_vnclip_wx_i8m${LMUL}(out_i16v, 0, __RISCV_VXRM_RNU, n);
      __riscv_vse8_v_i8m${LMUL}(output, out_i8v, n); output += n;
    $else:
      acc_i32v = __riscv_vmax_vx_i32m${LMUL*4}(acc_i32v, 0, n);
      vuint32m${LMUL*4}_t out_u32v = __riscv_vreinterpret_v_i32m${LMUL*4}_u32m${LMUL*4}(acc_i32v);
      vuint16m${LMUL*2}_t out_u16v = __riscv_vnclipu_wx_u16m${LMUL*2}(out_u32v, 8, __RISCV_VXRM_RDN, n);
      vuint8m${LMUL}_t out_u8v = __riscv_vnclipu_wx_u8m${LMUL}(out_u16v, 0, __RISCV_VXRM_RNU, n);
      __riscv_vse8_v_u8m${LMUL}(output, out_u8v, n); output += n;
  } while (batch != 0);
}

src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u1v.c

Lines changed: 54 additions & 0 deletions

// Auto-generated file. Do not edit!
//   Template: src/qs8-vlrelu/rvv.c.in
//   Generator: tools/xngen
//
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vlrelu.h"


void xnn_qu8_vlrelu_ukernel__rvv_u1v(
    size_t batch,
    const uint8_t* input,
    uint8_t* output,
    const struct xnn_qu8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(uint8_t) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const int32_t input_zero_point = params->scalar.input_zero_point;
  const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
  const int32_t multiplier_base = params->scalar.positive_multiplier;
  const int32_t bias = (params->scalar.output_zero_point << 8) + 128;
  int32_t n = __riscv_vsetvl_e8m1(batch);
  vint32m4_t bias_i32v = __riscv_vmv_v_x_i32m4(bias, n);

  do {
    n = __riscv_vsetvl_e8m1(batch); batch -= n;

    vuint8m1_t in_u8v = __riscv_vle8_v_u8m1(input, n); input += n;
    vuint16m2_t acc_u16v = __riscv_vwsubu_vx_u16m2(in_u8v, input_zero_point, n);
    vint16m2_t acc_i16v = __riscv_vreinterpret_v_u16m2_i16m2(acc_u16v);

    vint32m4_t acc_i32v = __riscv_vwcvt_x_x_v_i32m4(acc_i16v, n);
    vint32m4_t sra_i32v = __riscv_vsra_vx_i32m4(acc_i32v, 31, n);
    vint32m4_t and_i32v = __riscv_vand_vx_i32m4(sra_i32v, multiplier_diff, n);
    vint32m4_t mult_i32v = __riscv_vxor_vx_i32m4(and_i32v, multiplier_base, n);
    acc_i32v = __riscv_vmacc_vv_i32m4(bias_i32v, acc_i32v, mult_i32v, n);

    acc_i32v = __riscv_vmax_vx_i32m4(acc_i32v, 0, n);
    vuint32m4_t out_u32v = __riscv_vreinterpret_v_i32m4_u32m4(acc_i32v);
    vuint16m2_t out_u16v = __riscv_vnclipu_wx_u16m2(out_u32v, 8, __RISCV_VXRM_RDN, n);
    vuint8m1_t out_u8v = __riscv_vnclipu_wx_u8m1(out_u16v, 0, __RISCV_VXRM_RNU, n);
    __riscv_vse8_v_u8m1(output, out_u8v, n); output += n;
  } while (batch != 0);
}
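
One detail worth noting in the QU8 variant above: the signed accumulator can be negative before requantization, and vnclipu narrows an unsigned source, so the kernel clamps at zero (vmax) before reinterpreting the register as unsigned. A scalar sketch of that output tail, for illustration only and not part of the commit:

#include <stdint.h>

// Scalar restatement of the QU8 output path above: clamp the signed
// accumulator at zero, then floor-shift by 8 and saturate to uint8.
static inline uint8_t qu8_requantize_ref(int32_t acc) {
  if (acc < 0) acc = 0;                    // mirrors __riscv_vmax_vx_i32m1..m8(acc, 0, n)
  uint32_t shifted = (uint32_t) acc >> 8;  // mirrors vnclipu with shift 8, RDN rounding
  return (uint8_t) (shifted > UINT8_MAX ? UINT8_MAX : shifted);
}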
