Skip to content

Commit f77e78c

Browse files
committed
Neon Dot QP8 ukernels, tests, benchmarks
1 parent bcd169a commit f77e78c

16 files changed

+629
-10
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,7 @@ IF(XNNPACK_BUILD_TESTS)
16121612
qd8-f32-qc4w-gemm-minmax
16131613
qd8-f32-qc8w-igemm-minmax
16141614
qp8-f32-qc4w-gemm-minmax
1615+
qp8-f32-qb4w-gemm-minmax
16151616
qs8-qc8w-gemm-minmax-fp32
16161617
qs8-qc8w-igemm-minmax-fp32
16171618
qu8-gemm-minmax-fp32

bench/qp8-f32-qb4w-gemm.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
//
6+
// Auto-generated file. Do not edit!
7+
// Specification: test/qp8-f32-qb4w-gemm-minmax.yaml
8+
// Generator: tools/generate-gemm-test.py
9+
10+
#include <benchmark/benchmark.h>
11+
#include "bench/gemm-benchmark.h"
12+
#include "bench/utils.h"
13+
#include "xnnpack/common.h"
14+
#include "xnnpack/gemm.h"
15+
#include "xnnpack/isa-checks.h"
16+
#include "xnnpack/microfnptr.h"
17+
#include "xnnpack/microparams-init.h"
18+
#include "xnnpack/pack.h"
19+
#include "xnnpack/packw.h"
20+
21+
22+
#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64
23+
#if XNN_ENABLE_KLEIDIAI
24+
static void qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot(benchmark::State& state, const char* net) {
25+
GEMMBenchmark(state,
26+
xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot,
27+
xnn_init_f32_qb4w_minmax_scalar_params,
28+
xnn_pack_kai_qb4_weights_and_biases,
29+
xnn_packed_stride_kai_qb4_weights_and_biases,
30+
/*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/2,
31+
/*mr_packed=*/1,
32+
benchmark::utils::CheckNEONDOT);
33+
}
34+
35+
BENCHMARK_GEMM_BL(qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot)
36+
37+
static void qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot(benchmark::State& state, const char* net) {
38+
GEMMBenchmark(state,
39+
xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot,
40+
xnn_init_f32_qb4w_minmax_scalar_params,
41+
xnn_pack_kai_qb4_weights_and_biases,
42+
xnn_packed_stride_kai_qb4_weights_and_biases,
43+
/*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/2,
44+
/*mr_packed=*/1,
45+
benchmark::utils::CheckNEONDOT);
46+
}
47+
48+
BENCHMARK_GEMM_BL(qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot)
49+
#endif // XNN_ENABLE_KLEIDIAI
50+
#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64
51+
52+
53+
#ifndef XNNPACK_BENCHMARK_NO_MAIN
54+
BENCHMARK_MAIN();
55+
#endif

cmake/gen/neondot_aarch64_microkernels.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010

1111

1212
SET(PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS
13+
src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c
1314
src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c)
1415

1516
SET(NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS
1617
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-aarch64-neondot-ld128.c
1718
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-aarch64-neondot-ld128.c
1819
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-aarch64-neondot-ld128.c
1920
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-aarch64-neondot-ld128.c
21+
src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c
2022
src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x4c16s2-aarch64-neondot.c
2123
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-aarch64-neondot-ld128.c
2224
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-aarch64-neondot-ld128.c

gen/neondot_aarch64_microkernels.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Auto-generated file. Do not edit!
66
"""
77

88
PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [
9+
"src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c",
910
"src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c",
1011
]
1112

@@ -14,6 +15,7 @@ NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [
1415
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-aarch64-neondot-ld128.c",
1516
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-aarch64-neondot-ld128.c",
1617
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-aarch64-neondot-ld128.c",
18+
"src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c",
1719
"src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x4c16s2-aarch64-neondot.c",
1820
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-aarch64-neondot-ld128.c",
1921
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-aarch64-neondot-ld128.c",

scripts/generate-tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ tools/generate-gemm-test.py --spec test/qd8-f32-qc4w-gemm-minmax.yaml --output-t
4848
tools/generate-gemm-test.py --spec test/qd8-f32-qb4w-gemm-minmax.yaml --output-test test/qd8-f32-qb4w-gemm-minmax.cc --output-bench bench/qd8-f32-qb4w-gemm.cc &
4949

5050
tools/generate-gemm-test.py --spec test/qp8-f32-qc4w-gemm-minmax.yaml --output-test test/qp8-f32-qc4w-gemm-minmax.cc --output-bench bench/qp8-f32-qc4w-gemm.cc &
51+
tools/generate-gemm-test.py --spec test/qp8-f32-qb4w-gemm-minmax.yaml --output-test test/qp8-f32-qb4w-gemm-minmax.cc --output-bench bench/qp8-f32-qb4w-gemm.cc &
5152

5253
tools/generate-gemm-test.py --spec test/qs8-qc8w-gemm-minmax-fp32.yaml --output-test test/qs8-qc8w-gemm-minmax-fp32.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-2.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-3.cc --output-bench bench/qs8-qc8w-gemm-fp32.cc &
5354

src/packing.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#if XNN_ENABLE_KLEIDIAI
2525
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h"
2626
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h"
27+
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
2728
#endif // XNN_ENABLE_KLEIDIAI
2829

2930
#include <fp16/fp16.h>
@@ -1676,6 +1677,64 @@ void xnn_pack_kai_qs4_weights_and_biases(
16761677
&kai_params);
16771678
}
16781679
}
1680+
1681+
size_t xnn_packed_stride_kai_qb4_weights_and_biases(
1682+
const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size,
1683+
size_t extra_bytes) {
1684+
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
1685+
1686+
const size_t kai_num_bytes_sum_rhs = sizeof(float);
1687+
const size_t kai_num_bytes_bias = sizeof(float);
1688+
// TODO: derive the scale data type (currently hard-coded as Bf16 via sizeof(uint16_t) below) from the gemm_config instead of assuming it here.
1689+
// TODO: KleidiAI should expose its packed-stride helper in a public header so we can call it directly instead of duplicating the computation:
1690+
// return kai_rhs_packed_stride(k, /*nr=*/1, kr, block_size, Bf16);
1691+
const size_t num_bytes_multiplier_rhs = sizeof(uint16_t);
1692+
const size_t num_blocks_per_row = k/block_size;
1693+
const size_t num_bytes_per_block = (block_size / 2) + num_bytes_multiplier_rhs;
1694+
return 1 * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
1695+
}
1696+
1697+
void xnn_pack_kai_qb4_weights_and_biases(
1698+
uint32_t flags, const struct xnn_gemm_config* gemm_config,
1699+
size_t input_channels, size_t output_channels, size_t groups,
1700+
size_t block_size, const void* accumulator_init, const void* weights,
1701+
xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
1702+
size_t extra_data0_element_size,
1703+
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
1704+
size_t extra_data1_element_size, void* packed_weights_ptr,
1705+
const void* params) {
1706+
const uint32_t nr = gemm_config->nr;
1707+
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
1708+
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
1709+
const struct xnn_qs8_qc4w_packing_params* xnn_params =
1710+
reinterpret_cast<const struct xnn_qs8_qc4w_packing_params*>(params);
1711+
1712+
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
1713+
// No kxn (transposed/GIO) packing routine in KleidiAI as of now; only nxk is supported (used in the else-branch below).
1714+
xnn_log_fatal(
1715+
"KleidiAI does not currently have gio packing routine"
1716+
);
1717+
} else {
1718+
// Repack the packing params.
1719+
struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params kai_params;
1720+
kai_params.lhs_zero_point = xnn_params->input_zero_point;
1721+
kai_params.rhs_zero_point = xnn_params->kernel_zero_point;
1722+
kai_params.scale_dt = Bf16;
1723+
size_t rhs_stride = round_up_po2(input_channels, 2) / 2;
1724+
size_t blocks_per_row = (input_channels + block_size - 1) / block_size;
1725+
kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
1726+
groups, output_channels, input_channels, nr, kr, sr,
1727+
/*bl=*/block_size,
1728+
/*rhs=*/reinterpret_cast<const uint8_t*>(weights),
1729+
/*rhs_stride=*/rhs_stride,
1730+
/*bias=*/reinterpret_cast<const float*>(extra_data0),
1731+
/*scale=*/reinterpret_cast<const uint16_t*>(extra_data1),
1732+
/*scale_stride=*/blocks_per_row * sizeof(uint16_t),
1733+
/*rhs_packed*/packed_weights_ptr,
1734+
/*extra_bytes=*/0,
1735+
&kai_params);
1736+
}
1737+
}
16791738
#endif // XNN_ENABLE_KLEIDIAI
16801739

16811740
void xnn_pack_f32_qs8w_gemm_gio_w(
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
#include <stddef.h>
7+
#include "xnnpack/log.h"
8+
#include "xnnpack/microparams.h"
9+
10+
#if XNN_ENABLE_KLEIDIAI
11+
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
12+
#endif // XNN_ENABLE_KLEIDIAI
13+
14+
// Wraps the
15+
// `kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod` GEMM
16+
// microkernel with a name that is compatible with our tooling.
17+
void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot(
18+
size_t m, size_t n, size_t k, const void* lhs_packed,
19+
const void* rhs_packed, float* dst, size_t dst_stride_row,
20+
size_t dst_stride_col,
21+
const union xnn_f32_qb4w_minmax_params
22+
minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) {
23+
#if XNN_ENABLE_KLEIDIAI
24+
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
25+
m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col,
26+
minmax_params->scalar.min, minmax_params->scalar.max);
27+
#else
28+
xnn_log_fatal(
29+
"Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without "
30+
"`XNN_ENABLE_KLEIDIAI`.");
31+
#endif // XNN_ENABLE_KLEIDIAI
32+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
#include <stddef.h>
7+
#include "xnnpack/log.h"
8+
#include "xnnpack/microparams.h"
9+
10+
#if XNN_ENABLE_KLEIDIAI
11+
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"
12+
#endif // XNN_ENABLE_KLEIDIAI
13+
14+
// Wraps the
15+
// `kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod` GEMM
16+
// microkernel with a name that is compatible with our tooling.
17+
void xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot(
18+
size_t m, size_t n, size_t k, const void* lhs_packed,
19+
const void* rhs_packed, float* dst, size_t dst_stride_row,
20+
size_t dst_stride_col,
21+
const union xnn_f32_qb4w_minmax_params
22+
minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) {
23+
#if XNN_ENABLE_KLEIDIAI
24+
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
25+
m, n, k, minmax_params->scalar.blocksize, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col,
26+
minmax_params->scalar.min, minmax_params->scalar.max);
27+
#else
28+
xnn_log_fatal(
29+
"Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without "
30+
"`XNN_ENABLE_KLEIDIAI`.");
31+
#endif // XNN_ENABLE_KLEIDIAI
32+
}

src/xnnpack/gemm.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2206,6 +2206,17 @@ DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_u
22062206
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm)
22072207
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm)
22082208

2209+
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128)
2210+
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld128)
2211+
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld128)
2212+
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld128)
2213+
2214+
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld64)
2215+
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld64)
2216+
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld64)
2217+
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld64)
2218+
2219+
22092220
#define DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \
22102221
XNN_INTERNAL void fn_name( \
22112222
size_t m, \
@@ -2226,16 +2237,21 @@ DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_u
22262237
DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x4c16s2__aarch64_neoni8mm_mstep2)
22272238
DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x8c16s2__aarch64_neoni8mm_mstep2)
22282239

2229-
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128)
2230-
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld128)
2231-
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld128)
2232-
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld128)
2233-
2234-
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld64)
2235-
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld64)
2236-
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld64)
2237-
DECLARE_QD8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld64)
2240+
#define DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \
2241+
XNN_INTERNAL void fn_name( \
2242+
size_t m, \
2243+
size_t n, \
2244+
size_t k, \
2245+
const void* lhs_packed, \
2246+
const void* rhs_packed, \
2247+
float* dst, \
2248+
size_t dst_stride_row, \
2249+
size_t dst_stride_col, \
2250+
const union xnn_f32_qb4w_minmax_params \
2251+
minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]);
22382252

2253+
DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot)
2254+
DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot)
22392255

22402256

22412257
#define DECLARE_QD8_F16_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \

src/xnnpack/microfnptr.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,18 @@ typedef void (*xnn_qp8_f32_qc4w_gemm_minmax_ukernel_fn)(
339339
union xnn_f32_minmax_params
340340
minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]);
341341

342+
typedef void (*xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn)(
343+
size_t m,
344+
size_t n,
345+
size_t k,
346+
const void* lhs_packed,
347+
const void* rhs_packed,
348+
float* dst,
349+
size_t dst_stride_row,
350+
size_t dst_stride_col,
351+
const union xnn_f32_qb4w_minmax_params
352+
minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]);
353+
342354
// GEMMINC: GEMM INCremental with Min+Max activation
343355

344356
typedef void (*xnn_f32_gemminc_minmax_ukernel_fn)(

src/xnnpack/pack.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,30 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases(
483483
size_t k, //
484484
size_t k_stride, //
485485
size_t extra_bytes);
486+
487+
XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases(
488+
uint32_t flags, //
489+
const struct xnn_gemm_config* gemm_config, //
490+
size_t input_channels, //
491+
size_t output_channels, //
492+
size_t groups, //
493+
size_t block_size, //
494+
const void* accumulator_init, //
495+
const void* weights, //
496+
xnn_init_scale_params_fn init_extra_data0_fn, //
497+
const void* extra_data0, //
498+
size_t extra_data0_element_size, //
499+
xnn_init_scale_params_fn init_extra_data1_fn, //
500+
const void* extra_data1, //
501+
size_t extra_data1_element_size, //
502+
void* packed_weights_ptr, //
503+
const void* params);
504+
505+
XNN_INTERNAL size_t xnn_packed_stride_kai_qb4_weights_and_biases(
506+
const struct xnn_gemm_config* gemm_config, //
507+
size_t k, //
508+
size_t block_size, //
509+
size_t extra_bytes);
486510
#endif // XNN_ENABLE_KLEIDIAI
487511

488512
XNN_INTERNAL void xnn_pack_qs8_to_qu8_gemm_gio_w(

src/xnnpack/packq.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,37 @@ XNN_INLINE static int8_t xnn_x8_packq_f32qp8_get_quantized(
8686
return *dst_ptr;
8787
}
8888

89+
XNN_INLINE static float xnn_x8_packq_f32qp8_get_recip_scale(
90+
size_t m_idx, const int8_t* lhs_packed, size_t k,
91+
size_t mr_packed, size_t kr, size_t sr) {
92+
const size_t k_internal = k_roundedup(k, kr, sr);
93+
const size_t dst_x = (m_idx % mr_packed);
94+
const size_t packed_offset =
95+
xnn_x8_packq_f32qp8_packed_offset(m_idx, k, mr_packed, kr, sr);
96+
97+
// Get the quantization parameters.
98+
const int8_t* dst_ptr = lhs_packed + packed_offset + mr_packed * k_internal;
99+
dst_ptr += dst_x * sizeof(int32_t);
100+
dst_ptr += mr_packed * sizeof(float);
101+
const float recip_scale = *(const float*)dst_ptr;
102+
return recip_scale;
103+
}
104+
105+
XNN_INLINE static float xnn_x8_packq_f32qp8_get_neg_nudged_zp(
106+
size_t m_idx, const int8_t* lhs_packed, size_t k,
107+
size_t mr_packed, size_t kr, size_t sr) {
108+
const size_t k_internal = k_roundedup(k, kr, sr);
109+
const size_t dst_x = (m_idx % mr_packed);
110+
const size_t packed_offset =
111+
xnn_x8_packq_f32qp8_packed_offset(m_idx, k, mr_packed, kr, sr);
112+
113+
// Get the quantization parameters.
114+
const int8_t* dst_ptr = lhs_packed + packed_offset + mr_packed * k_internal;
115+
dst_ptr += dst_x * sizeof(int32_t);
116+
const int32_t neg_nudged_zero_point = *(const int32_t*)dst_ptr;
117+
return neg_nudged_zero_point;
118+
}
119+
89120
XNN_INLINE static float xnn_x8_packq_f32qp8_get_dequantized(
90121
size_t m_idx, size_t k_idx, const int8_t* lhs_packed, size_t k,
91122
size_t mr_packed, size_t kr, size_t sr) {

0 commit comments

Comments
 (0)