Skip to content

Commit

Permalink
Wired QP8_QB4W Operator APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
mcr229 committed Sep 4, 2024
1 parent f77e78c commit eefcad9
Show file tree
Hide file tree
Showing 13 changed files with 696 additions and 21 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1246,6 +1246,7 @@ IF(XNNPACK_BUILD_TESTS)
IF(XNNPACK_BUILD_LIBRARY)
# ---[ Launch heavy tests first.
SET(LIBRARY_SHARDED_TESTS
fully-connected-nc
batch-matrix-multiply-nc
batch-matrix-multiply
deconvolution-nhwc
Expand Down
38 changes: 38 additions & 0 deletions src/configs/gemm-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static struct xnn_gemm_config qd8_f32_qb4w_gemm_config = {0};
static struct xnn_gemm_config qd8_f32_qc4w_gemm_config = {0};
static struct xnn_gemm_config qd8_f32_qc8w_gemm_config = {0};
static struct xnn_gemm_config qp8_f32_qc4w_gemm_config = {0};
static struct xnn_gemm_config qp8_f32_qb4w_gemm_config = {0};
static struct xnn_gemm_config qs8_qc8w_gemm_config = {0};
static struct xnn_gemm_config qu8_gemm_config = {0};

Expand All @@ -55,6 +56,7 @@ XNN_INIT_ONCE_GUARD(qd8_f32_qb4w_gemm);
XNN_INIT_ONCE_GUARD(qd8_f32_qc4w_gemm);
XNN_INIT_ONCE_GUARD(qd8_f32_qc8w_gemm);
XNN_INIT_ONCE_GUARD(qp8_f32_qc4w_gemm);
XNN_INIT_ONCE_GUARD(qp8_f32_qb4w_gemm);
XNN_INIT_ONCE_GUARD(qs8_qc8w_gemm);
XNN_INIT_ONCE_GUARD(qu8_gemm);

Expand Down Expand Up @@ -1745,6 +1747,28 @@ static void init_qp8_f32_qc4w_gemm_config(void) {
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

static void init_qp8_f32_qb4w_gemm_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
const struct xnn_hardware_config* hardware_config =
xnn_init_hardware_config();
assert(hardware_config != NULL);
if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) {
#if XNN_ENABLE_ARM_DOTPROD
qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot);
qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params;
qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.mr = 1;
qp8_f32_qb4w_gemm_config.nr = 8;
qp8_f32_qb4w_gemm_config.log2_kr = 4;
qp8_f32_qb4w_gemm_config.log2_sr = 1;
qp8_f32_qb4w_gemm_config.planes = 2;
qp8_f32_qb4w_gemm_config.mr_packed = 1;
#endif // XNN_ENABLE_ARM_DOTPROD
}
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

static void init_qd8_f32_qb4w_gemm_config(void) {
qd8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w;

Expand Down Expand Up @@ -3863,6 +3887,20 @@ XNN_INIT_ONCE(qp8_f32_qc4w_gemm);
return NULL;
}

struct xnn_gemm_config* xnn_init_qp8_f32_qb4w_gemm_config() {
const struct xnn_hardware_config* hardware_config =
xnn_init_hardware_config();
if (hardware_config == NULL) {
return NULL;
}
XNN_INIT_ONCE(qp8_f32_qb4w_gemm);
// Only return the config pointer if it actually provides a kernel.
if (qp8_f32_qb4w_gemm_config.minmax.qp8gemm[0].function[0] != NULL) {
return &qp8_f32_qb4w_gemm_config;
}
return NULL;
}

struct xnn_gemm_config* xnn_init_qs8_qc8w_gemm_config() {
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
if (hardware_config == NULL) {
Expand Down
13 changes: 7 additions & 6 deletions src/enums/operator-type.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@

#include "xnnpack/operator-type.h"

static const uint16_t offset[169] = {
static const uint16_t offset[170] = {
0, 8, 22, 36, 50, 64, 78, 92, 119, 147, 175, 203, 230, 257, 289, 321, 364, 382, 400, 425, 451, 467, 483, 498, 513,
535, 558, 581, 604, 627, 650, 673, 696, 719, 742, 760, 783, 806, 830, 848, 871, 895, 919, 943, 967, 1002, 1037, 1061,
1085, 1109, 1123, 1138, 1153, 1173, 1199, 1225, 1262, 1288, 1318, 1344, 1376, 1408, 1434, 1461, 1488, 1505, 1522,
1556, 1590, 1604, 1618, 1632, 1646, 1662, 1678, 1704, 1730, 1762, 1794, 1831, 1868, 1905, 1942, 1979, 2016, 2053,
2079, 2111, 2137, 2152, 2186, 2220, 2254, 2288, 2322, 2356, 2386, 2416, 2436, 2456, 2477, 2498, 2519, 2540, 2554,
2578, 2602, 2625, 2648, 2666, 2684, 2699, 2714, 2732, 2750, 2769, 2788, 2807, 2826, 2845, 2862, 2879, 2895, 2911,
2944, 2977, 3005, 3033, 3061, 3089, 3116, 3143, 3160, 3177, 3218, 3259, 3277, 3295, 3313, 3331, 3346, 3362, 3378,
3396, 3414, 3432, 3458, 3485, 3512, 3529, 3546, 3568, 3590, 3619, 3648, 3667, 3686, 3705, 3724, 3739, 3754, 3769,
3784, 3803, 3823, 3843, 3863, 3884, 3905
2090, 2116, 2148, 2174, 2189, 2223, 2257, 2291, 2325, 2359, 2393, 2423, 2453, 2473, 2493, 2514, 2535, 2556, 2577,
2591, 2615, 2639, 2662, 2685, 2703, 2721, 2736, 2751, 2769, 2787, 2806, 2825, 2844, 2863, 2882, 2899, 2916, 2932,
2948, 2981, 3014, 3042, 3070, 3098, 3126, 3153, 3180, 3197, 3214, 3255, 3296, 3314, 3332, 3350, 3368, 3383, 3399,
3415, 3433, 3451, 3469, 3495, 3522, 3549, 3566, 3583, 3605, 3627, 3656, 3685, 3704, 3723, 3742, 3761, 3776, 3791,
3806, 3821, 3840, 3860, 3880, 3900, 3921, 3942
};

static const char data[] =
Expand Down Expand Up @@ -110,6 +110,7 @@ static const char data[] =
"Fully Connected (NC, QD8, F32, QC4W)\0"
"Fully Connected (NC, QD8, F32, QC8W)\0"
"Fully Connected (NC, QP8, F32, QC4W)\0"
"Fully Connected (NC, QP8, F32, QB4W)\0"
"Fully Connected (NC, QS8)\0"
"Fully Connected (NC, QS8, QC8W)\0"
"Fully Connected (NC, QU8)\0"
Expand Down
2 changes: 2 additions & 0 deletions src/enums/operator-type.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@
string: "Fully Connected (NC, QD8, F32, QC8W)"
- name: xnn_operator_type_fully_connected_nc_qp8_f32_qc4w
string: "Fully Connected (NC, QP8, F32, QC4W)"
- name: xnn_operator_type_fully_connected_nc_qp8_f32_qb4w
string: "Fully Connected (NC, QP8, F32, QB4W)"
- name: xnn_operator_type_fully_connected_nc_qs8
string: "Fully Connected (NC, QS8)"
- name: xnn_operator_type_fully_connected_nc_qs8_qc8w
Expand Down
27 changes: 27 additions & 0 deletions src/operator-run.c
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,33 @@ void xnn_compute_qp8gemm(
nr_block_start, mr_block_size, nr_block_size);
}

void xnn_compute_hmp_qp8gemm_bl(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
size_t mr_block_size, size_t nr_block_size) {
const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset(
mr_block_start, context->k_scaled, context->mr, context->kr, context->sr);
const size_t cm_stride = context->cm_stride;

context->qp8_bl_ukernel.function[uarch_index](
mr_block_size, nr_block_size, context->k_scaled,
(const void*)((uintptr_t)context->a + a_offset),
(const void*)((uintptr_t)context->packed_w +
nr_block_start * context->w_stride),
(void*)((uintptr_t)context->c + mr_block_start * cm_stride +
(nr_block_start << context->log2_csize)),
cm_stride,
/*dst_stride_col=*/sizeof(float), context->fused_params);
}

void xnn_compute_qp8gemm_bl(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
size_t nr_block_size) {
xnn_compute_hmp_qp8gemm_bl(context, XNN_UARCH_DEFAULT, mr_block_start,
nr_block_start, mr_block_size, nr_block_size);
}

void xnn_compute_hmp_dqgemm_bl(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
Expand Down
Loading

0 comments on commit eefcad9

Please sign in to comment.