Skip to content

Commit eefcad9

Browse files
committed
Wired QP8_QB4W Operator APIs
1 parent f77e78c commit eefcad9

13 files changed

+696
-21
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,6 +1246,7 @@ IF(XNNPACK_BUILD_TESTS)
12461246
IF(XNNPACK_BUILD_LIBRARY)
12471247
# ---[ Launch heavy tests first.
12481248
SET(LIBRARY_SHARDED_TESTS
1249+
fully-connected-nc
12491250
batch-matrix-multiply-nc
12501251
batch-matrix-multiply
12511252
deconvolution-nhwc

src/configs/gemm-config.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ static struct xnn_gemm_config qd8_f32_qb4w_gemm_config = {0};
4040
static struct xnn_gemm_config qd8_f32_qc4w_gemm_config = {0};
4141
static struct xnn_gemm_config qd8_f32_qc8w_gemm_config = {0};
4242
static struct xnn_gemm_config qp8_f32_qc4w_gemm_config = {0};
43+
static struct xnn_gemm_config qp8_f32_qb4w_gemm_config = {0};
4344
static struct xnn_gemm_config qs8_qc8w_gemm_config = {0};
4445
static struct xnn_gemm_config qu8_gemm_config = {0};
4546

@@ -55,6 +56,7 @@ XNN_INIT_ONCE_GUARD(qd8_f32_qb4w_gemm);
5556
XNN_INIT_ONCE_GUARD(qd8_f32_qc4w_gemm);
5657
XNN_INIT_ONCE_GUARD(qd8_f32_qc8w_gemm);
5758
XNN_INIT_ONCE_GUARD(qp8_f32_qc4w_gemm);
59+
XNN_INIT_ONCE_GUARD(qp8_f32_qb4w_gemm);
5860
XNN_INIT_ONCE_GUARD(qs8_qc8w_gemm);
5961
XNN_INIT_ONCE_GUARD(qu8_gemm);
6062

@@ -1745,6 +1747,28 @@ static void init_qp8_f32_qc4w_gemm_config(void) {
17451747
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
17461748
}
17471749

1750+
// Populates the file-level qp8_f32_qb4w_gemm_config.
//
// The body is compiled only when XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI holds;
// otherwise this function does nothing. Inside such builds the config is
// written only when XNN_ENABLE_ARM_DOTPROD is set and the hardware config
// reports use_arm_neon_dot; in every other case the config is left untouched.
static void init_qp8_f32_qb4w_gemm_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
  const struct xnn_hardware_config* hw_config = xnn_init_hardware_config();
  assert(hw_config != NULL);

  // Guard clause: nothing to register without dot-product support.
  if (!(XNN_ENABLE_ARM_DOTPROD && hw_config->use_arm_neon_dot)) {
    return;
  }

  #if XNN_ENABLE_ARM_DOTPROD
    struct xnn_gemm_config* config = &qp8_f32_qb4w_gemm_config;

    // Tile geometry; matches the "1x8c16s2" in the micro-kernel name
    // (mr = 1, nr = 8, kr = 1 << 4, sr = 1 << 1).
    config->mr = 1;
    config->mr_packed = 1;
    config->nr = 8;
    config->log2_kr = 4;
    config->log2_sr = 1;
    config->planes = 2;

    // Micro-kernel for the mr == 1 slot and its parameter initializer.
    config->minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] =
        xnn_init_hmp_qp8gemm_bl_ukernel(
            xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot);
    config->init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params;

    // Weight packing routines paired with the kernel above.
    config->pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases;
    config->packed_stride_weights_and_biases =
        xnn_packed_stride_kai_qb4_weights_and_biases;
  #endif  // XNN_ENABLE_ARM_DOTPROD
#endif  // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}
1771+
17481772
static void init_qd8_f32_qb4w_gemm_config(void) {
17491773
qd8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w;
17501774

@@ -3863,6 +3887,20 @@ XNN_INIT_ONCE(qp8_f32_qc4w_gemm);
38633887
return NULL;
38643888
}
38653889

3890+
struct xnn_gemm_config* xnn_init_qp8_f32_qb4w_gemm_config() {
3891+
const struct xnn_hardware_config* hardware_config =
3892+
xnn_init_hardware_config();
3893+
if (hardware_config == NULL) {
3894+
return NULL;
3895+
}
3896+
XNN_INIT_ONCE(qp8_f32_qb4w_gemm);
3897+
// Only return the config pointer if it actually provides a kernel.
3898+
if (qp8_f32_qb4w_gemm_config.minmax.qp8gemm[0].function[0] != NULL) {
3899+
return &qp8_f32_qb4w_gemm_config;
3900+
}
3901+
return NULL;
3902+
}
3903+
38663904
struct xnn_gemm_config* xnn_init_qs8_qc8w_gemm_config() {
38673905
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
38683906
if (hardware_config == NULL) {

src/enums/operator-type.c

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212

1313
#include "xnnpack/operator-type.h"
1414

15-
static const uint16_t offset[169] = {
15+
static const uint16_t offset[170] = {
1616
0, 8, 22, 36, 50, 64, 78, 92, 119, 147, 175, 203, 230, 257, 289, 321, 364, 382, 400, 425, 451, 467, 483, 498, 513,
1717
535, 558, 581, 604, 627, 650, 673, 696, 719, 742, 760, 783, 806, 830, 848, 871, 895, 919, 943, 967, 1002, 1037, 1061,
1818
1085, 1109, 1123, 1138, 1153, 1173, 1199, 1225, 1262, 1288, 1318, 1344, 1376, 1408, 1434, 1461, 1488, 1505, 1522,
1919
1556, 1590, 1604, 1618, 1632, 1646, 1662, 1678, 1704, 1730, 1762, 1794, 1831, 1868, 1905, 1942, 1979, 2016, 2053,
20-
2079, 2111, 2137, 2152, 2186, 2220, 2254, 2288, 2322, 2356, 2386, 2416, 2436, 2456, 2477, 2498, 2519, 2540, 2554,
21-
2578, 2602, 2625, 2648, 2666, 2684, 2699, 2714, 2732, 2750, 2769, 2788, 2807, 2826, 2845, 2862, 2879, 2895, 2911,
22-
2944, 2977, 3005, 3033, 3061, 3089, 3116, 3143, 3160, 3177, 3218, 3259, 3277, 3295, 3313, 3331, 3346, 3362, 3378,
23-
3396, 3414, 3432, 3458, 3485, 3512, 3529, 3546, 3568, 3590, 3619, 3648, 3667, 3686, 3705, 3724, 3739, 3754, 3769,
24-
3784, 3803, 3823, 3843, 3863, 3884, 3905
20+
2090, 2116, 2148, 2174, 2189, 2223, 2257, 2291, 2325, 2359, 2393, 2423, 2453, 2473, 2493, 2514, 2535, 2556, 2577,
21+
2591, 2615, 2639, 2662, 2685, 2703, 2721, 2736, 2751, 2769, 2787, 2806, 2825, 2844, 2863, 2882, 2899, 2916, 2932,
22+
2948, 2981, 3014, 3042, 3070, 3098, 3126, 3153, 3180, 3197, 3214, 3255, 3296, 3314, 3332, 3350, 3368, 3383, 3399,
23+
3415, 3433, 3451, 3469, 3495, 3522, 3549, 3566, 3583, 3605, 3627, 3656, 3685, 3704, 3723, 3742, 3761, 3776, 3791,
24+
3806, 3821, 3840, 3860, 3880, 3900, 3921, 3942
2525
};
2626

2727
static const char data[] =
@@ -110,6 +110,7 @@ static const char data[] =
110110
"Fully Connected (NC, QD8, F32, QC4W)\0"
111111
"Fully Connected (NC, QD8, F32, QC8W)\0"
112112
"Fully Connected (NC, QP8, F32, QC4W)\0"
113+
"Fully Connected (NC, QP8, F32, QB4W)\0"
113114
"Fully Connected (NC, QS8)\0"
114115
"Fully Connected (NC, QS8, QC8W)\0"
115116
"Fully Connected (NC, QU8)\0"

src/enums/operator-type.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@
175175
string: "Fully Connected (NC, QD8, F32, QC8W)"
176176
- name: xnn_operator_type_fully_connected_nc_qp8_f32_qc4w
177177
string: "Fully Connected (NC, QP8, F32, QC4W)"
178+
- name: xnn_operator_type_fully_connected_nc_qp8_f32_qb4w
179+
string: "Fully Connected (NC, QP8, F32, QB4W)"
178180
- name: xnn_operator_type_fully_connected_nc_qs8
179181
string: "Fully Connected (NC, QS8)"
180182
- name: xnn_operator_type_fully_connected_nc_qs8_qc8w

src/operator-run.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,33 @@ void xnn_compute_qp8gemm(
527527
nr_block_start, mr_block_size, nr_block_size);
528528
}
529529

530+
void xnn_compute_hmp_qp8gemm_bl(
531+
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
532+
uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
533+
size_t mr_block_size, size_t nr_block_size) {
534+
const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset(
535+
mr_block_start, context->k_scaled, context->mr, context->kr, context->sr);
536+
const size_t cm_stride = context->cm_stride;
537+
538+
context->qp8_bl_ukernel.function[uarch_index](
539+
mr_block_size, nr_block_size, context->k_scaled,
540+
(const void*)((uintptr_t)context->a + a_offset),
541+
(const void*)((uintptr_t)context->packed_w +
542+
nr_block_start * context->w_stride),
543+
(void*)((uintptr_t)context->c + mr_block_start * cm_stride +
544+
(nr_block_start << context->log2_csize)),
545+
cm_stride,
546+
/*dst_stride_col=*/sizeof(float), context->fused_params);
547+
}
548+
549+
void xnn_compute_qp8gemm_bl(
550+
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
551+
size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
552+
size_t nr_block_size) {
553+
xnn_compute_hmp_qp8gemm_bl(context, XNN_UARCH_DEFAULT, mr_block_start,
554+
nr_block_start, mr_block_size, nr_block_size);
555+
}
556+
530557
void xnn_compute_hmp_dqgemm_bl(
531558
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
532559
uint32_t uarch_index,

0 commit comments

Comments
 (0)