Skip to content

Commit

Permalink
fix init_qp8_qb4_config
Browse files Browse the repository at this point in the history
  • Loading branch information
mcr229 committed Sep 13, 2024
1 parent 4cf6c02 commit 277bf61
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 35 deletions.
68 changes: 34 additions & 34 deletions src/configs/gemm-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -1734,39 +1734,39 @@ static void init_qp8_f32_qc4w_gemm_config(void) {
}

// Populates the file-level qp8_f32_qb4w_gemm_config based on the hardware
// config returned by xnn_init_hardware_config().
//
// The body is compiled only when XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI; in any
// other build this function writes nothing. Inside that guard:
//   - if XNN_ENABLE_ARM_I8MM is non-zero and hardware_config->use_arm_neon_i8mm
//     is set, a 1x4c16s2 neondot kernel is stored at XNN_MR_TO_INDEX(1), an
//     8x4c16s2 neoni8mm (mstep2) kernel at XNN_MR_TO_INDEX(8), and the tile is
//     described as mr=8, nr=4, mr_packed=4;
//   - otherwise, if XNN_ENABLE_ARM_DOTPROD is non-zero and
//     hardware_config->use_arm_neon_dot is set, only a 1x8c16s2 neondot kernel
//     is stored at XNN_MR_TO_INDEX(1), with mr=1, nr=8, mr_packed=1;
//   - otherwise no field of the config is written.
// Both branches set log2_kr=4, log2_sr=1, planes=2, the
// xnn_init_f32_qb4w_minmax_scalar_params initializer, and the same
// xnn_pack_kai_qb4_weights_and_biases /
// xnn_packed_stride_kai_qb4_weights_and_biases callbacks.
//
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// guarded body appears twice below. The two copies differ only in the symbol
// assigned to the XNN_MR_TO_INDEX(8) slot; only one copy can exist in the real
// file (two declarations of hardware_config in one scope would not compile).
static void init_qp8_f32_qb4w_gemm_config(void) {
// NOTE(review): first copy of the body -- presumably the pre-change side of
// the diff; confirm against the repository history.
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
const struct xnn_hardware_config* hardware_config =
xnn_init_hardware_config();
assert(hardware_config != NULL);
if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) {
// The inner #if repeats the build-time check from the condition above;
// presumably this keeps the kernel symbols from being referenced in builds
// where XNN_ENABLE_ARM_I8MM is 0 -- TODO confirm.
#if XNN_ENABLE_ARM_I8MM
// The MR=1 slot uses a neondot kernel even in the i8mm branch.
qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot);
// This copy names the MR=8 kernel with an `aarch64_` infix.
qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__aarch64_neoni8mm_mstep2);
qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params;
qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.mr = 8;
qp8_f32_qb4w_gemm_config.nr = 4;
qp8_f32_qb4w_gemm_config.log2_kr = 4;
qp8_f32_qb4w_gemm_config.log2_sr = 1;
qp8_f32_qb4w_gemm_config.planes = 2;
qp8_f32_qb4w_gemm_config.mr_packed = 4;
#endif // XNN_ENABLE_ARM_I8MM
} else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) {
#if XNN_ENABLE_ARM_DOTPROD
// Dot-product-only path: a single MR=1 kernel with a wider (nr=8) tile.
qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot);
qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params;
qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.mr = 1;
qp8_f32_qb4w_gemm_config.nr = 8;
qp8_f32_qb4w_gemm_config.log2_kr = 4;
qp8_f32_qb4w_gemm_config.log2_sr = 1;
qp8_f32_qb4w_gemm_config.planes = 2;
qp8_f32_qb4w_gemm_config.mr_packed = 1;
#endif // XNN_ENABLE_ARM_DOTPROD
}
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
// NOTE(review): second copy of the body -- presumably the post-change side of
// the diff. It matches the first copy except for the MR=8 kernel symbol.
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
const struct xnn_hardware_config* hardware_config =
xnn_init_hardware_config();
assert(hardware_config != NULL);
if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) {
#if XNN_ENABLE_ARM_I8MM
qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot);
// This copy names the MR=8 kernel without the `aarch64_` infix.
qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2);
qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params;
qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.mr = 8;
qp8_f32_qb4w_gemm_config.nr = 4;
qp8_f32_qb4w_gemm_config.log2_kr = 4;
qp8_f32_qb4w_gemm_config.log2_sr = 1;
qp8_f32_qb4w_gemm_config.planes = 2;
qp8_f32_qb4w_gemm_config.mr_packed = 4;
#endif // XNN_ENABLE_ARM_I8MM
} else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) {
#if XNN_ENABLE_ARM_DOTPROD
qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_bl_ukernel(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot);
qp8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params;
qp8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qb4_weights_and_biases;
qp8_f32_qb4w_gemm_config.mr = 1;
qp8_f32_qb4w_gemm_config.nr = 8;
qp8_f32_qb4w_gemm_config.log2_kr = 4;
qp8_f32_qb4w_gemm_config.log2_sr = 1;
qp8_f32_qb4w_gemm_config.planes = 2;
qp8_f32_qb4w_gemm_config.mr_packed = 1;
#endif // XNN_ENABLE_ARM_DOTPROD
}
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

static void init_qd8_f32_qb4w_gemm_config(void) {
Expand Down Expand Up @@ -3926,7 +3926,7 @@ struct xnn_gemm_config* xnn_init_qp8_f32_qb4w_gemm_config() {
}
XNN_INIT_ONCE(qp8_f32_qb4w_gemm);
// Only return the config pointer if it actually provides a kernel.
if (qp8_f32_qb4w_gemm_config.minmax.qp8gemm[0].function[0] != NULL) {
if (qp8_f32_qb4w_gemm_config.minmax.qp8gemm_bl[0].function[0] != NULL) {
return &qp8_f32_qb4w_gemm_config;
}
return NULL;
Expand Down
3 changes: 2 additions & 1 deletion src/subgraph/convert.c
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,8 @@ enum xnn_status xnn_define_convert(
if ((flags & XNN_FLAG_MAYBE_PACK_FOR_GEMM) &&
input_value->datatype == xnn_datatype_fp32 &&
output_value->datatype == xnn_datatype_qdint8 &&
xnn_init_qp8_f32_qc4w_gemm_config() != NULL) {
(xnn_init_qp8_f32_qc4w_gemm_config() != NULL ||
xnn_init_qp8_f32_qb4w_gemm_config() != NULL)) {
xnn_log_debug("Coercing type of output ID #%" PRIu32
" of %s operator from `%s` to `%s`.",
output_id, xnn_node_type_to_string(xnn_node_type_convert),
Expand Down

0 comments on commit 277bf61

Please sign in to comment.