Skip to content

Commit

Permalink
Neon Dot QP8 ukernels, tests, benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
mcr229 committed Sep 4, 2024
1 parent bcd169a commit f77e78c
Show file tree
Hide file tree
Showing 16 changed files with 629 additions and 10 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,7 @@ IF(XNNPACK_BUILD_TESTS)
qd8-f32-qc4w-gemm-minmax
qd8-f32-qc8w-igemm-minmax
qp8-f32-qc4w-gemm-minmax
qp8-f32-qb4w-gemm-minmax
qs8-qc8w-gemm-minmax-fp32
qs8-qc8w-igemm-minmax-fp32
qu8-gemm-minmax-fp32
Expand Down
55 changes: 55 additions & 0 deletions bench/qp8-f32-qb4w-gemm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: test/qp8-f32-qb4w-gemm-minmax.yaml
// Generator: tools/generate-gemm-test.py

#include <benchmark/benchmark.h>
#include "bench/gemm-benchmark.h"
#include "bench/utils.h"
#include "xnnpack/common.h"
#include "xnnpack/gemm.h"
#include "xnnpack/isa-checks.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microparams-init.h"
#include "xnnpack/pack.h"
#include "xnnpack/packw.h"


#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64
#if XNN_ENABLE_KLEIDIAI
// Benchmark driver for the KleidiAI-backed qp8-f32-qb4w GEMM microkernel with
// tile parameters mr=1, nr=4, kr=16, sr=2 (skipped unless NEON dot-product is
// available, via benchmark::utils::CheckNEONDOT).
static void qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot(benchmark::State& state, const char* net) {
  GEMMBenchmark(state,
                xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot,
                xnn_init_f32_qb4w_minmax_scalar_params,
                xnn_pack_kai_qb4_weights_and_biases,
                xnn_packed_stride_kai_qb4_weights_and_biases,
                /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/2,
                /*mr_packed=*/1,
                benchmark::utils::CheckNEONDOT);
}

// BENCHMARK_GEMM_BL registers the benchmark across the blocked-weights (BL)
// problem-size sweep.
BENCHMARK_GEMM_BL(qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot)

// Benchmark driver for the KleidiAI-backed qp8-f32-qb4w GEMM microkernel with
// tile parameters mr=1, nr=8, kr=16, sr=2 (skipped unless NEON dot-product is
// available, via benchmark::utils::CheckNEONDOT).
static void qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot(benchmark::State& state, const char* net) {
  GEMMBenchmark(state,
                xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot,
                xnn_init_f32_qb4w_minmax_scalar_params,
                xnn_pack_kai_qb4_weights_and_biases,
                xnn_packed_stride_kai_qb4_weights_and_biases,
                /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/2,
                /*mr_packed=*/1,
                benchmark::utils::CheckNEONDOT);
}

// BENCHMARK_GEMM_BL registers the benchmark across the blocked-weights (BL)
// problem-size sweep.
BENCHMARK_GEMM_BL(qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot)
#endif // XNN_ENABLE_KLEIDIAI
#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64


#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif
2 changes: 2 additions & 0 deletions cmake/gen/neondot_aarch64_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@


SET(PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS
src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c
src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c)

SET(NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-aarch64-neondot-ld128.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-aarch64-neondot-ld128.c
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-aarch64-neondot-ld128.c
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-aarch64-neondot-ld128.c
src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c
src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x4c16s2-aarch64-neondot.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-aarch64-neondot-ld128.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-aarch64-neondot-ld128.c
Expand Down
2 changes: 2 additions & 0 deletions gen/neondot_aarch64_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Auto-generated file. Do not edit!
"""

PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [
"src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c",
"src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c",
]

Expand All @@ -14,6 +15,7 @@ NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-aarch64-neondot-ld128.c",
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-aarch64-neondot-ld128.c",
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-aarch64-neondot-ld128.c",
"src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c",
"src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x4c16s2-aarch64-neondot.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-aarch64-neondot-ld128.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-aarch64-neondot-ld128.c",
Expand Down
1 change: 1 addition & 0 deletions scripts/generate-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ tools/generate-gemm-test.py --spec test/qd8-f32-qc4w-gemm-minmax.yaml --output-t
tools/generate-gemm-test.py --spec test/qd8-f32-qb4w-gemm-minmax.yaml --output-test test/qd8-f32-qb4w-gemm-minmax.cc --output-bench bench/qd8-f32-qb4w-gemm.cc &

tools/generate-gemm-test.py --spec test/qp8-f32-qc4w-gemm-minmax.yaml --output-test test/qp8-f32-qc4w-gemm-minmax.cc --output-bench bench/qp8-f32-qc4w-gemm.cc &
tools/generate-gemm-test.py --spec test/qp8-f32-qb4w-gemm-minmax.yaml --output-test test/qp8-f32-qb4w-gemm-minmax.cc --output-bench bench/qp8-f32-qb4w-gemm.cc &

tools/generate-gemm-test.py --spec test/qs8-qc8w-gemm-minmax-fp32.yaml --output-test test/qs8-qc8w-gemm-minmax-fp32.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-2.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-3.cc --output-bench bench/qs8-qc8w-gemm-fp32.cc &

Expand Down
59 changes: 59 additions & 0 deletions src/packing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#if XNN_ENABLE_KLEIDIAI
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
#endif // XNN_ENABLE_KLEIDIAI

#include <fp16/fp16.h>
Expand Down Expand Up @@ -1676,6 +1677,64 @@ void xnn_pack_kai_qs4_weights_and_biases(
&kai_params);
}
}

// Returns the per-output-channel stride, in bytes, of the packed weights
// produced by xnn_pack_kai_qb4_weights_and_biases.
//
// Per row, the KleidiAI qsi4c32p layout stores one entry per block of
// `block_size` input channels -- (block_size / 2) bytes of nibble-packed
// weights plus one bf16 (uint16_t) scale -- followed by a float row sum and a
// float bias.
//
// `gemm_config` and `extra_bytes` are accepted but unused; they are kept for
// signature compatibility with the packed-stride function-pointer contract.
size_t xnn_packed_stride_kai_qb4_weights_and_biases(
    const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size,
    size_t extra_bytes) {
  (void) gemm_config;  // Previously read log2_kr here, but the value was never used.
  (void) extra_bytes;

  const size_t kai_num_bytes_sum_rhs = sizeof(float);
  const size_t kai_num_bytes_bias = sizeof(float);
  // TODO: derive the scale dtype (bf16) from gemm_config, and have KleidiAI
  // expose kai_rhs_packed_stride() in a header instead of re-deriving the
  // layout here:
  //   return kai_rhs_packed_stride(k, /*nr=*/1, kr, block_size, Bf16);
  const size_t num_bytes_multiplier_rhs = sizeof(uint16_t);  // one bf16 scale per block
  // Round up so a partial trailing block is fully covered; this matches the
  // ceil-division used for blocks_per_row in xnn_pack_kai_qb4_weights_and_biases.
  const size_t num_blocks_per_row = (k + block_size - 1) / block_size;
  const size_t num_bytes_per_block = (block_size / 2) + num_bytes_multiplier_rhs;
  return (num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias;
}

// Packs blockwise-quantized 4-bit weights, per-block bf16 scales
// (`extra_data1`) and per-channel float biases (`extra_data0`) into the
// KleidiAI qsi4c32p RHS layout consumed by the qp8-f32-qb4w GEMM
// microkernels. Only the nxk (non-transposed) orientation is supported;
// XNN_FLAG_TRANSPOSE_WEIGHTS aborts because KleidiAI has no gio routine yet.
void xnn_pack_kai_qb4_weights_and_biases(
    uint32_t flags, const struct xnn_gemm_config* gemm_config,
    size_t input_channels, size_t output_channels, size_t groups,
    size_t block_size, const void* accumulator_init, const void* weights,
    xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
    size_t extra_data0_element_size,
    xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
    size_t extra_data1_element_size, void* packed_weights_ptr,
    const void* params) {
  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    // KleidiAI only ships an nxk packing kernel; gio (transposed) input is
    // not supported as of now.
    xnn_log_fatal(
        "KleidiAI does not currently have gio packing routine"
    );
    return;
  }

  // Tile parameters come from the GEMM config; kr/sr are stored as log2.
  const uint32_t nr = gemm_config->nr;
  const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;

  // Translate the XNNPACK packing params into KleidiAI's parameter struct.
  const struct xnn_qs8_qc4w_packing_params* packing_params =
      static_cast<const struct xnn_qs8_qc4w_packing_params*>(params);
  struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params kai_params;
  kai_params.lhs_zero_point = packing_params->input_zero_point;
  kai_params.rhs_zero_point = packing_params->kernel_zero_point;
  kai_params.scale_dt = Bf16;  // per-block scales are stored as bf16

  // Two 4-bit weights per byte; odd channel counts round up to a whole byte.
  const size_t rhs_stride = round_up_po2(input_channels, 2) / 2;
  // One bf16 scale per block of `block_size` input channels, rounded up.
  const size_t blocks_per_row = (input_channels + block_size - 1) / block_size;

  kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
      groups, output_channels, input_channels, nr, kr, sr,
      /*bl=*/block_size,
      /*rhs=*/static_cast<const uint8_t*>(weights),
      /*rhs_stride=*/rhs_stride,
      /*bias=*/static_cast<const float*>(extra_data0),
      /*scale=*/static_cast<const uint16_t*>(extra_data1),
      /*scale_stride=*/blocks_per_row * sizeof(uint16_t),
      /*rhs_packed=*/packed_weights_ptr,
      /*extra_bytes=*/0,
      &kai_params);
}
#endif // XNN_ENABLE_KLEIDIAI

void xnn_pack_f32_qs8w_gemm_gio_w(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <stddef.h>
#include "xnnpack/log.h"
#include "xnnpack/microparams.h"

#if XNN_ENABLE_KLEIDIAI
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#endif // XNN_ENABLE_KLEIDIAI

// Wraps the
// `kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod` GEMM
// microkernel with a name that is compatible with our tooling.
// (The previous comment named the qsi4cxp (channelwise) kernel; this wrapper
// calls the qsi4c32p (blockwise) variant.)
//
// `lhs_packed`/`rhs_packed` must already be in the KleidiAI packed layouts.
// The quantization block size is taken from minmax_params->scalar.blocksize,
// and outputs are clamped to [scalar.min, scalar.max].
void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot(
    size_t m, size_t n, size_t k, const void* lhs_packed,
    const void* rhs_packed, float* dst, size_t dst_stride_row,
    size_t dst_stride_col,
    const union xnn_f32_qb4w_minmax_params
        minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) {
#if XNN_ENABLE_KLEIDIAI
  kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
      m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col,
      minmax_params->scalar.min, minmax_params->scalar.max);
#else
  // Guard: this translation unit can be built without KleidiAI, in which case
  // calling the wrapper is a hard error rather than silent misbehavior.
  xnn_log_fatal(
      "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without "
      "`XNN_ENABLE_KLEIDIAI`.");
#endif  // XNN_ENABLE_KLEIDIAI
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <stddef.h>
#include "xnnpack/log.h"
#include "xnnpack/microparams.h"

#if XNN_ENABLE_KLEIDIAI
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"
#endif // XNN_ENABLE_KLEIDIAI

// Wraps the
// `kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod` GEMM
// microkernel with a name that is compatible with our tooling.
// (The previous comment named the qsi4cxp (channelwise) kernel; this wrapper
// calls the qsi4c32p (blockwise) variant.)
//
// `lhs_packed`/`rhs_packed` must already be in the KleidiAI packed layouts.
// The quantization block size is taken from minmax_params->scalar.blocksize,
// and outputs are clamped to [scalar.min, scalar.max].
void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot(
    size_t m, size_t n, size_t k, const void* lhs_packed,
    const void* rhs_packed, float* dst, size_t dst_stride_row,
    size_t dst_stride_col,
    const union xnn_f32_qb4w_minmax_params
        minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) {
#if XNN_ENABLE_KLEIDIAI
  kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
      m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col,
      minmax_params->scalar.min, minmax_params->scalar.max);
#else
  // Guard: this translation unit can be built without KleidiAI, in which case
  // calling the wrapper is a hard error rather than silent misbehavior.
  xnn_log_fatal(
      "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without "
      "`XNN_ENABLE_KLEIDIAI`.");
#endif  // XNN_ENABLE_KLEIDIAI
}
34 changes: 25 additions & 9 deletions src/xnnpack/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -2206,6 +2206,17 @@ DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_u
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm)

DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld128)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld128)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld128)

DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld64)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld64)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld64)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld64)


#define DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \
XNN_INTERNAL void fn_name( \
size_t m, \
Expand All @@ -2226,16 +2237,21 @@ DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_u
DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x4c16s2__aarch64_neoni8mm_mstep2)
DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x8c16s2__aarch64_neoni8mm_mstep2)

DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld128)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld128)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld128)

DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld64)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld64)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld64)
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld64)
// Declares a qp8-f32-qb4w GEMM microkernel: pre-packed dynamically-quantized
// int8 LHS (`lhs_packed`) times blockwise 4-bit quantized RHS (`rhs_packed`),
// producing f32 output clamped by `minmax_params`. The signature matches the
// xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn function-pointer type.
#define DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name(                                       \
      size_t m,                                                    \
      size_t n,                                                    \
      size_t k,                                                    \
      const void* lhs_packed,                                      \
      const void* rhs_packed,                                      \
      float* dst,                                                  \
      size_t dst_stride_row,                                       \
      size_t dst_stride_col,                                       \
      const union xnn_f32_qb4w_minmax_params                       \
          minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]);

DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot)
DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot)


#define DECLARE_QD8_F16_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \
Expand Down
12 changes: 12 additions & 0 deletions src/xnnpack/microfnptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,18 @@ typedef void (*xnn_qp8_f32_qc4w_gemm_minmax_ukernel_fn)(
union xnn_f32_minmax_params
minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]);

// Function-pointer type for qp8-f32-qb4w GEMM microkernels: pre-packed
// dynamically-quantized int8 LHS times blockwise 4-bit quantized RHS,
// producing f32 output clamped by `minmax_params`.
typedef void (*xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn)(
    size_t m,
    size_t n,
    size_t k,
    const void* lhs_packed,
    const void* rhs_packed,
    float* dst,
    size_t dst_stride_row,
    size_t dst_stride_col,
    const union xnn_f32_qb4w_minmax_params
        minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]);

// GEMMINC: GEMM INCremental with Min+Max activation

typedef void (*xnn_f32_gemminc_minmax_ukernel_fn)(
Expand Down
24 changes: 24 additions & 0 deletions src/xnnpack/pack.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,30 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases(
size_t k, //
size_t k_stride, //
size_t extra_bytes);

XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases(
uint32_t flags, //
const struct xnn_gemm_config* gemm_config, //
size_t input_channels, //
size_t output_channels, //
size_t groups, //
size_t block_size, //
const void* accumulator_init, //
const void* weights, //
xnn_init_scale_params_fn init_extra_data0_fn, //
const void* extra_data0, //
size_t extra_data0_element_size, //
xnn_init_scale_params_fn init_extra_data1_fn, //
const void* extra_data1, //
size_t extra_data1_element_size, //
void* packed_weights_ptr, //
const void* params);

XNN_INTERNAL size_t xnn_packed_stride_kai_qb4_weights_and_biases(
const struct xnn_gemm_config* gemm_config, //
size_t k, //
size_t block_size, //
size_t extra_bytes);
#endif // XNN_ENABLE_KLEIDIAI

XNN_INTERNAL void xnn_pack_qs8_to_qu8_gemm_gio_w(
Expand Down
31 changes: 31 additions & 0 deletions src/xnnpack/packq.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,37 @@ XNN_INLINE static int8_t xnn_x8_packq_f32qp8_get_quantized(
return *dst_ptr;
}

// Reads back the reciprocal quantization scale stored for row `m_idx` of a
// qp8-packed LHS buffer.
//
// Layout assumption: within each group of `mr_packed` rows, the int8
// quantized data (mr_packed * k_internal bytes) is followed by mr_packed
// int32 negated-nudged zero points and then mr_packed float reciprocal
// scales -- TODO confirm against the qp8 packing routine.
// NOTE(review): the offset arithmetic below reaches the dst_x-th float via
// `dst_x * sizeof(int32_t) + mr_packed * sizeof(float)`, which is only
// equivalent to `mr_packed * sizeof(int32_t) + dst_x * sizeof(float)`
// because sizeof(int32_t) == sizeof(float).
XNN_INLINE static float xnn_x8_packq_f32qp8_get_recip_scale(
    size_t m_idx, const int8_t* lhs_packed, size_t k,
    size_t mr_packed, size_t kr, size_t sr) {
  const size_t k_internal = k_roundedup(k, kr, sr);
  // Row position within its packed group of mr_packed rows.
  const size_t dst_x = (m_idx % mr_packed);
  const size_t packed_offset =
      xnn_x8_packq_f32qp8_packed_offset(m_idx, k, mr_packed, kr, sr);

  // Get the quantization parameters.
  const int8_t* dst_ptr = lhs_packed + packed_offset + mr_packed * k_internal;
  dst_ptr += dst_x * sizeof(int32_t);
  dst_ptr += mr_packed * sizeof(float);
  const float recip_scale = *(const float*)dst_ptr;
  return recip_scale;
}

// Reads back the negated, nudged zero point stored for row `m_idx` of a
// qp8-packed LHS buffer. The value lives as an int32 immediately after the
// quantized data for the row group (before the float reciprocal scales --
// TODO confirm against the qp8 packing routine).
//
// NOTE(review): the declared return type is float, so the int32 value is
// implicitly converted on return (exact for magnitudes below 2^24); consider
// returning int32_t if callers can be migrated.
XNN_INLINE static float xnn_x8_packq_f32qp8_get_neg_nudged_zp(
    size_t m_idx, const int8_t* lhs_packed, size_t k,
    size_t mr_packed, size_t kr, size_t sr) {
  const size_t k_internal = k_roundedup(k, kr, sr);
  // Row position within its packed group of mr_packed rows.
  const size_t dst_x = (m_idx % mr_packed);
  const size_t packed_offset =
      xnn_x8_packq_f32qp8_packed_offset(m_idx, k, mr_packed, kr, sr);

  // Get the quantization parameters.
  const int8_t* dst_ptr = lhs_packed + packed_offset + mr_packed * k_internal;
  dst_ptr += dst_x * sizeof(int32_t);
  const int32_t neg_nudged_zero_point = *(const int32_t*)dst_ptr;
  return neg_nudged_zero_point;
}

XNN_INLINE static float xnn_x8_packq_f32qp8_get_dequantized(
size_t m_idx, size_t k_idx, const int8_t* lhs_packed, size_t k,
size_t mr_packed, size_t kr, size_t sr) {
Expand Down
Loading

0 comments on commit f77e78c

Please sign in to comment.