|
| 1 | +cpu: aarch64: add sbgemm (fp32 input and bf16 weights) inner |
| 2 | + product op |
| 3 | + |
| 4 | +--- |
| 5 | + src/cpu/aarch64/acl_inner_product.hpp | 8 ++++++-- |
| 6 | + src/cpu/cpu_inner_product_list.cpp | 4 ++++ |
| 7 | + 2 files changed, 10 insertions(+), 2 deletions(-) |
| 8 | + |
| 9 | +diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp |
| 10 | +index a2be164f0..eca56b289 100644 |
| 11 | +--- a/src/cpu/aarch64/acl_inner_product.hpp |
| 12 | ++++ b/src/cpu/aarch64/acl_inner_product.hpp |
| 13 | +@@ -99,9 +99,13 @@ struct acl_inner_product_fwd_t : public primitive_t { |
| 14 | + const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) |
| 15 | + && attr()->has_default_values( |
| 16 | + primitive_attr_t::skip_mask_t::post_ops, f32); |
| 17 | ++ const bool is_fp32_bf16_ok |
| 18 | ++ = expect_data_types(f32, bf16, f32, f32, undef) |
| 19 | ++ && attr()->has_default_values( |
| 20 | ++ primitive_attr_t::skip_mask_t::post_ops, f32); |
| 21 | + const bool ok = is_fwd() && !has_zero_dim_memory() |
| 22 | +- && utils::one_of(true, is_fp16_ok, is_fp32_ok) |
| 23 | +- && weights_md_.format_kind == format_kind::any |
| 24 | ++ && utils::one_of( |
| 25 | ++ true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok) |
| 26 | + && set_default_params() == status::success; |
| 27 | + |
| 28 | + if (!ok) return status::unimplemented; |
| 29 | +diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp |
| 30 | +index fdd7b1776..5a3dc1ea7 100644 |
| 31 | +--- a/src/cpu/cpu_inner_product_list.cpp |
| 32 | ++++ b/src/cpu/cpu_inner_product_list.cpp |
| 33 | +@@ -83,6 +83,10 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() |
| 34 | + CPU_INSTANCE(ref_inner_product_fwd_t) |
| 35 | + nullptr, |
| 36 | + }}, |
| 37 | ++ {{forward, f32, bf16, f32}, { |
| 38 | ++ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t) |
| 39 | ++ nullptr, |
| 40 | ++ }}, |
| 41 | + {{backward_data, f32, f32, f32}, REG_BWD_PK({ |
| 42 | + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32 |
| 43 | + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>) |
| 44 | +-- |
| 45 | +2.34.1 |
| 46 | + |
0 commit comments