Skip to content

Commit

Permalink
Add qs8/qu8 vlrelu kernels, configs aand tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
oliIMG committed Sep 5, 2024
1 parent eb4486b commit ef8f6b6
Show file tree
Hide file tree
Showing 15 changed files with 336 additions and 6 deletions.
15 changes: 15 additions & 0 deletions bench/qs8-vlrelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,21 @@ static void qs8_vlrelu(
->UseRealTime();
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD

#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
BENCHMARK_CAPTURE(qs8_vlrelu, rvv_u1v,
xnn_qs8_vlrelu_ukernel__rvv_u1v,
xnn_init_qs8_lrelu_scalar_params,
benchmark::utils::CheckRVV)
->Apply(benchmark::utils::BinaryElementwiseParameters<int8_t, int8_t>)
->UseRealTime();
BENCHMARK_CAPTURE(qs8_vlrelu, rvv_u2v,
xnn_qs8_vlrelu_ukernel__rvv_u2v,
xnn_init_qs8_lrelu_scalar_params,
benchmark::utils::CheckRVV)
->Apply(benchmark::utils::BinaryElementwiseParameters<int8_t, int8_t>)
->UseRealTime();
#endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV

BENCHMARK_CAPTURE(qs8_vlrelu, scalar_andxor_u1,
xnn_qs8_vlrelu_ukernel__scalar_andxor_u1,
xnn_init_qs8_lrelu_scalar_params)
Expand Down
15 changes: 15 additions & 0 deletions bench/qu8-vlrelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,21 @@ static void qu8_vlrelu(
->UseRealTime();
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD

#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
BENCHMARK_CAPTURE(qu8_vlrelu, rvv_u1v,
xnn_qu8_vlrelu_ukernel__rvv_u1v,
xnn_init_qu8_lrelu_scalar_params,
benchmark::utils::CheckRVV)
->Apply(benchmark::utils::BinaryElementwiseParameters<uint8_t, uint8_t>)
->UseRealTime();
BENCHMARK_CAPTURE(qu8_vlrelu, rvv_u2v,
xnn_qu8_vlrelu_ukernel__rvv_u2v,
xnn_init_qu8_lrelu_scalar_params,
benchmark::utils::CheckRVV)
->Apply(benchmark::utils::BinaryElementwiseParameters<uint8_t, uint8_t>)
->UseRealTime();
#endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV

BENCHMARK_CAPTURE(qu8_vlrelu, scalar_andxor_u1,
xnn_qu8_vlrelu_ukernel__scalar_andxor_u1,
xnn_init_qu8_lrelu_scalar_params)
Expand Down
4 changes: 4 additions & 0 deletions cmake/gen/rvv_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ SET(PROD_RVV_MICROKERNEL_SRCS
src/f32-vlrelu/gen/f32-vlrelu-rvv-u4v.c
src/f32-vrelu/gen/f32-vrelu-rvv-u4v.c
src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c
src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c
src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c
src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c
src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u2v.c
src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u2v.c
src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u2v.c
src/x32-transposec/gen/x32-transposec-4x4-rvv.c
Expand Down Expand Up @@ -155,8 +157,10 @@ SET(NON_PROD_RVV_MICROKERNEL_SRCS
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c
src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u1v.c
src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u1v.c
src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u1v.c
src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u1v.c
src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u1v.c
src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u1v.c
src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u2.c
Expand Down
4 changes: 4 additions & 0 deletions gen/rvv_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ PROD_RVV_MICROKERNEL_SRCS = [
"src/f32-vlrelu/gen/f32-vlrelu-rvv-u4v.c",
"src/f32-vrelu/gen/f32-vrelu-rvv-u4v.c",
"src/f32-vrsqrt/gen/f32-vrsqrt-rvv-rsqrt-u4v.c",
"src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c",
"src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u2v.c",
"src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u2v.c",
"src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u2v.c",
"src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u2v.c",
"src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u2v.c",
"src/x32-transposec/gen/x32-transposec-4x4-rvv.c",
Expand Down Expand Up @@ -152,8 +154,10 @@ NON_PROD_RVV_MICROKERNEL_SRCS = [
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x4v-minmax-rvv.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x4v-minmax-rvv.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x4v-minmax-rvv.c",
"src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u1v.c",
"src/qs8-vmul/gen/qs8-vmul-minmax-f32-rvv-u1v.c",
"src/qs8-vmulc/gen/qs8-vmulc-minmax-f32-rvv-u1v.c",
"src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u1v.c",
"src/qu8-vmul/gen/qu8-vmul-minmax-f32-rvv-u1v.c",
"src/qu8-vmulc/gen/qu8-vmulc-minmax-f32-rvv-u1v.c",
"src/x32-packw/gen/x32-packw-x1v-gemm-goi-rvv-u2.c",
Expand Down
8 changes: 8 additions & 0 deletions scripts/generate-qs8-vlrelu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ tools/xngen src/qs8-vlrelu/armsimd32.c.in -D BATCH_TILE=8 -D DATATYPE=QS8 -o src
tools/xngen src/qs8-vlrelu/armsimd32.c.in -D BATCH_TILE=4 -D DATATYPE=QU8 -o src/qu8-vlrelu/gen/qu8-vlrelu-armsimd32-u4.c &
tools/xngen src/qs8-vlrelu/armsimd32.c.in -D BATCH_TILE=8 -D DATATYPE=QU8 -o src/qu8-vlrelu/gen/qu8-vlrelu-armsimd32-u8.c &

################################ RISC-V Vector ################################
tools/xngen src/qs8-vlrelu/rvv.c.in -D LMUL=1 -D DATATYPE=QS8 -o src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u1v.c &
tools/xngen src/qs8-vlrelu/rvv.c.in -D LMUL=2 -D DATATYPE=QS8 -o src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c &

tools/xngen src/qs8-vlrelu/rvv.c.in -D LMUL=1 -D DATATYPE=QU8 -o src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u1v.c &
tools/xngen src/qs8-vlrelu/rvv.c.in -D LMUL=2 -D DATATYPE=QU8 -o src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u2v.c &


#################################### Scalar ###################################
tools/xngen src/qs8-vlrelu/scalar-select.c.in -D BATCH_TILE=1 -D DATATYPE=QS8 -o src/qs8-vlrelu/gen/qs8-vlrelu-scalar-select-u1.c &
tools/xngen src/qs8-vlrelu/scalar-select.c.in -D BATCH_TILE=2 -D DATATYPE=QS8 -o src/qs8-vlrelu/gen/qs8-vlrelu-scalar-select-u2.c &
Expand Down
14 changes: 8 additions & 6 deletions src/configs/unary-elementwise-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -1907,10 +1907,11 @@ static void init_qs8_lrelu_config(void) {
qs8_lrelu_config.init.qs8_lrelu = xnn_init_qs8_lrelu_scalar_params;
qs8_lrelu_config.element_tile = 4;
}
#elif XNN_ARCH_RISCV
qs8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_vlrelu_ukernel__scalar_andxor_u4;
#elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
qs8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_vlrelu_ukernel__rvv_u2v;
qs8_lrelu_config.init.qs8_lrelu = xnn_init_qs8_lrelu_scalar_params;
qs8_lrelu_config.element_tile = 4;
qs8_lrelu_config.element_tile = hardware_config->vlenb / sizeof(int8_t) * 2;
#else
qs8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qs8_vlrelu_ukernel__scalar_andxor_u4;
qs8_lrelu_config.init.qs8_lrelu = xnn_init_qs8_lrelu_scalar_params;
Expand Down Expand Up @@ -2149,10 +2150,11 @@ static void init_qu8_lrelu_config(void) {
qu8_lrelu_config.init.qu8_lrelu = xnn_init_qu8_lrelu_scalar_params;
qu8_lrelu_config.element_tile = 4;
}
#elif XNN_ARCH_RISCV
qu8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_vlrelu_ukernel__scalar_andxor_u4;
#elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
qu8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_vlrelu_ukernel__rvv_u2v;
qu8_lrelu_config.init.qu8_lrelu = xnn_init_qu8_lrelu_scalar_params;
qu8_lrelu_config.element_tile = 4;
qu8_lrelu_config.element_tile = hardware_config->vlenb / sizeof(int8_t) * 2;
#else
qu8_lrelu_config.ukernel = (xnn_vunary_ukernel_fn) xnn_qu8_vlrelu_ukernel__scalar_andxor_u4;
qu8_lrelu_config.init.qu8_lrelu = xnn_init_qu8_lrelu_scalar_params;
Expand Down
51 changes: 51 additions & 0 deletions src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u1v.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Auto-generated file. Do not edit!
// Template: src/qs8-vlrelu/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vlrelu.h"


void xnn_qs8_vlrelu_ukernel__rvv_u1v(
size_t batch,
const int8_t* input,
int8_t* output,
const struct xnn_qs8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(int8_t) == 0);
assert(input != NULL);
assert(output != NULL);

const int32_t input_zero_point = params->scalar.input_zero_point;
const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
const int32_t multiplier_base = params->scalar.positive_multiplier;
const int32_t bias = (params->scalar.output_zero_point << 8) + 128;
int32_t n = __riscv_vsetvl_e8m1(batch);
vint32m4_t bias_i32v = __riscv_vmv_v_x_i32m4(bias, n);

do {
n = __riscv_vsetvl_e8m1(batch); batch -= n;

vint8m1_t in_i8v = __riscv_vle8_v_i8m1(input, n); input += n;
vint16m2_t acc_i16v = __riscv_vwsub_vx_i16m2(in_i8v, input_zero_point, n);

vint32m4_t acc_i32v = __riscv_vwcvt_x_x_v_i32m4(acc_i16v, n);
vint32m4_t sra_i32v = __riscv_vsra_vx_i32m4(acc_i32v, 31, n);
vint32m4_t and_i32v = __riscv_vand_vx_i32m4(sra_i32v, multiplier_diff, n);
vint32m4_t mult_i32v = __riscv_vxor_vx_i32m4(and_i32v, multiplier_base, n);
acc_i32v = __riscv_vmacc_vv_i32m4(bias_i32v, acc_i32v, mult_i32v, n);

vint16m2_t out_i16v = __riscv_vnclip_wx_i16m2(acc_i32v, 8, __RISCV_VXRM_RDN, n);
vint8m1_t out_i8v = __riscv_vnclip_wx_i8m1(out_i16v, 0, __RISCV_VXRM_RNU, n);
__riscv_vse8_v_i8m1(output, out_i8v, n); output += n;
} while (batch != 0);
}
51 changes: 51 additions & 0 deletions src/qs8-vlrelu/gen/qs8-vlrelu-rvv-u2v.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Auto-generated file. Do not edit!
// Template: src/qs8-vlrelu/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vlrelu.h"


void xnn_qs8_vlrelu_ukernel__rvv_u2v(
size_t batch,
const int8_t* input,
int8_t* output,
const struct xnn_qs8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(int8_t) == 0);
assert(input != NULL);
assert(output != NULL);

const int32_t input_zero_point = params->scalar.input_zero_point;
const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
const int32_t multiplier_base = params->scalar.positive_multiplier;
const int32_t bias = (params->scalar.output_zero_point << 8) + 128;
int32_t n = __riscv_vsetvl_e8m2(batch);
vint32m8_t bias_i32v = __riscv_vmv_v_x_i32m8(bias, n);

do {
n = __riscv_vsetvl_e8m2(batch); batch -= n;

vint8m2_t in_i8v = __riscv_vle8_v_i8m2(input, n); input += n;
vint16m4_t acc_i16v = __riscv_vwsub_vx_i16m4(in_i8v, input_zero_point, n);

vint32m8_t acc_i32v = __riscv_vwcvt_x_x_v_i32m8(acc_i16v, n);
vint32m8_t sra_i32v = __riscv_vsra_vx_i32m8(acc_i32v, 31, n);
vint32m8_t and_i32v = __riscv_vand_vx_i32m8(sra_i32v, multiplier_diff, n);
vint32m8_t mult_i32v = __riscv_vxor_vx_i32m8(and_i32v, multiplier_base, n);
acc_i32v = __riscv_vmacc_vv_i32m8(bias_i32v, acc_i32v, mult_i32v, n);

vint16m4_t out_i16v = __riscv_vnclip_wx_i16m4(acc_i32v, 8, __RISCV_VXRM_RDN, n);
vint8m2_t out_i8v = __riscv_vnclip_wx_i8m2(out_i16v, 0, __RISCV_VXRM_RNU, n);
__riscv_vse8_v_i8m2(output, out_i8v, n); output += n;
} while (batch != 0);
}
5 changes: 5 additions & 0 deletions src/qs8-vlrelu/qs8-vlrelu.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ XNN_UKERNEL_WITH_PARAMS(xnn_arch_arm_neon, xnn_qs8_vlrelu_ukernel__neon_u16, 16,
XNN_UKERNEL_WITH_PARAMS(xnn_arch_arm_neon, xnn_qs8_vlrelu_ukernel__neon_u32, 32, false, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64

#if XNN_ENABLE_RISCV_VECTOR && (XNN_ARCH_RISCV)
XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vlrelu_ukernel__rvv_u1v, 1, true, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
XNN_UKERNEL_WITH_PARAMS(xnn_arch_riscv_vector, xnn_qs8_vlrelu_ukernel__rvv_u2v, 2, true, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
#endif // XNN_ENABLE_RISCV_VECTOR && (XNN_ARCH_RISCV)

#if XNN_ARCH_X86 || XNN_ARCH_X86_64
XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vlrelu_ukernel__sse2_u16, 16, false, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
XNN_UKERNEL_WITH_PARAMS(0, xnn_qs8_vlrelu_ukernel__sse2_u32, 32, false, int8_t, union xnn_qs8_lrelu_minmax_params, xnn_init_qs8_lrelu_scalar_params)
Expand Down
62 changes: 62 additions & 0 deletions src/qs8-vlrelu/rvv.c.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL in [1, 2, 4, 8]
$assert DATATYPE in ["QS8", "QU8"]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vlrelu.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]

void xnn_${DATATYPE.lower()}_vlrelu_ukernel__rvv_u${LMUL}v(
size_t batch,
const ${XINT8_T}* input,
${XINT8_T}* output,
const struct xnn_${DATATYPE.lower()}_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(${XINT8_T}) == 0);
assert(input != NULL);
assert(output != NULL);

const int32_t input_zero_point = params->scalar.input_zero_point;
const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
const int32_t multiplier_base = params->scalar.positive_multiplier;
const int32_t bias = (params->scalar.output_zero_point << 8) + 128;
int32_t n = __riscv_vsetvl_e8m${LMUL}(batch);
vint32m${LMUL*4}_t bias_i32v = __riscv_vmv_v_x_i32m${LMUL*4}(bias, n);

do {
n = __riscv_vsetvl_e8m${LMUL}(batch); batch -= n;

$if DATATYPE == "QS8":
vint8m${LMUL}_t in_i8v = __riscv_vle8_v_i8m${LMUL}(input, n); input += n;
vint16m${LMUL*2}_t acc_i16v = __riscv_vwsub_vx_i16m${LMUL*2}(in_i8v, input_zero_point, n);
$else:
vuint8m${LMUL}_t in_u8v = __riscv_vle8_v_u8m${LMUL}(input, n); input += n;
vuint16m${LMUL*2}_t acc_u16v = __riscv_vwsubu_vx_u16m${LMUL*2}(in_u8v, input_zero_point, n);
vint16m${LMUL*2}_t acc_i16v = __riscv_vreinterpret_v_u16m${LMUL*2}_i16m${LMUL*2}(acc_u16v);

vint32m${LMUL*4}_t acc_i32v = __riscv_vwcvt_x_x_v_i32m${LMUL*4}(acc_i16v, n);
vint32m${LMUL*4}_t sra_i32v = __riscv_vsra_vx_i32m${LMUL*4}(acc_i32v, 31, n);
vint32m${LMUL*4}_t and_i32v = __riscv_vand_vx_i32m${LMUL*4}(sra_i32v, multiplier_diff, n);
vint32m${LMUL*4}_t mult_i32v = __riscv_vxor_vx_i32m${LMUL*4}(and_i32v, multiplier_base, n);
acc_i32v = __riscv_vmacc_vv_i32m${LMUL*4}(bias_i32v, acc_i32v, mult_i32v, n);

$if DATATYPE == "QS8":
vint16m${LMUL*2}_t out_i16v = __riscv_vnclip_wx_i16m${LMUL*2}(acc_i32v, 8, __RISCV_VXRM_RDN, n);
vint8m${LMUL}_t out_i8v = __riscv_vnclip_wx_i8m${LMUL}(out_i16v, 0, __RISCV_VXRM_RNU, n);
__riscv_vse8_v_i8m${LMUL}(output, out_i8v, n); output += n;
$else:
acc_i32v = __riscv_vmax_vx_i32m${LMUL*4}(acc_i32v, 0, n);
vuint32m${LMUL*4}_t out_u32v = __riscv_vreinterpret_v_i32m${LMUL*4}_u32m${LMUL*4}(acc_i32v);
vuint16m${LMUL*2}_t out_u16v =__riscv_vnclipu_wx_u16m${LMUL*2}(out_u32v, 8, __RISCV_VXRM_RDN, n);
vuint8m${LMUL}_t out_u8v = __riscv_vnclipu_wx_u8m${LMUL}(out_u16v, 0, __RISCV_VXRM_RNU, n);
__riscv_vse8_v_u8m${LMUL}(output, out_u8v, n); output += n;
} while (batch != 0);
}
54 changes: 54 additions & 0 deletions src/qu8-vlrelu/gen/qu8-vlrelu-rvv-u1v.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Auto-generated file. Do not edit!
// Template: src/qs8-vlrelu/rvv.c.in
// Generator: tools/xngen
//
// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vlrelu.h"


void xnn_qu8_vlrelu_ukernel__rvv_u1v(
size_t batch,
const uint8_t* input,
uint8_t* output,
const struct xnn_qu8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(batch != 0);
assert(batch % sizeof(uint8_t) == 0);
assert(input != NULL);
assert(output != NULL);

const int32_t input_zero_point = params->scalar.input_zero_point;
const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
const int32_t multiplier_base = params->scalar.positive_multiplier;
const int32_t bias = (params->scalar.output_zero_point << 8) + 128;
int32_t n = __riscv_vsetvl_e8m1(batch);
vint32m4_t bias_i32v = __riscv_vmv_v_x_i32m4(bias, n);

do {
n = __riscv_vsetvl_e8m1(batch); batch -= n;

vuint8m1_t in_u8v = __riscv_vle8_v_u8m1(input, n); input += n;
vuint16m2_t acc_u16v = __riscv_vwsubu_vx_u16m2(in_u8v, input_zero_point, n);
vint16m2_t acc_i16v = __riscv_vreinterpret_v_u16m2_i16m2(acc_u16v);

vint32m4_t acc_i32v = __riscv_vwcvt_x_x_v_i32m4(acc_i16v, n);
vint32m4_t sra_i32v = __riscv_vsra_vx_i32m4(acc_i32v, 31, n);
vint32m4_t and_i32v = __riscv_vand_vx_i32m4(sra_i32v, multiplier_diff, n);
vint32m4_t mult_i32v = __riscv_vxor_vx_i32m4(and_i32v, multiplier_base, n);
acc_i32v = __riscv_vmacc_vv_i32m4(bias_i32v, acc_i32v, mult_i32v, n);

acc_i32v = __riscv_vmax_vx_i32m4(acc_i32v, 0, n);
vuint32m4_t out_u32v = __riscv_vreinterpret_v_i32m4_u32m4(acc_i32v);
vuint16m2_t out_u16v =__riscv_vnclipu_wx_u16m2(out_u32v, 8, __RISCV_VXRM_RDN, n);
vuint8m1_t out_u8v = __riscv_vnclipu_wx_u8m1(out_u16v, 0, __RISCV_VXRM_RNU, n);
__riscv_vse8_v_u8m1(output, out_u8v, n); output += n;
} while (batch != 0);
}
Loading

0 comments on commit ef8f6b6

Please sign in to comment.