Skip to content

Commit 77d657e

Browse files
committed
[aarch64] patch mkldnn acl inner product to accelerate torch.compile() for bert
1 parent b92da8c commit 77d657e

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

aarch64_linux/aarch64_wheel_ci_build.py

+3
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def parse_arguments():
108108
# work around to fix Raspberry pie crash
109109
print("Applying mkl-dnn patch to fix readdir crash")
110110
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/aarch64-fix-readdir-crash.patch")
111+
# patch acl inner product to accelerate torch.compile() path
112+
print("Applying mkl-dnn patch to acl inner product")
113+
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch")
111114
os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel")
112115
pytorch_wheel_name = complete_wheel("pytorch")
113116
print(f"Build Compelete. Created {pytorch_wheel_name}..")

aarch64_linux/build_aarch64_wheel.py

+1
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ def start_build(host: RemoteHost, *,
555555
print("build pytorch with mkldnn+acl backend")
556556
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
557557
host.run_cmd(f"cd $HOME && git clone https://github.com/pytorch/builder.git")
558+
host.run_cmd(f"cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch")
558559
host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}")
559560
print('Repair the wheel')
560561
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
cpu: aarch64: add sbgemm (fp32 input and bf16 weights) inner
2+
product op
3+
4+
---
5+
src/cpu/aarch64/acl_inner_product.hpp | 8 ++++++--
6+
src/cpu/cpu_inner_product_list.cpp | 4 ++++
7+
2 files changed, 10 insertions(+), 2 deletions(-)
8+
9+
diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp
10+
index a2be164f0..eca56b289 100644
11+
--- a/src/cpu/aarch64/acl_inner_product.hpp
12+
+++ b/src/cpu/aarch64/acl_inner_product.hpp
13+
@@ -99,9 +99,13 @@ struct acl_inner_product_fwd_t : public primitive_t {
14+
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
15+
&& attr()->has_default_values(
16+
primitive_attr_t::skip_mask_t::post_ops, f32);
17+
+ const bool is_fp32_bf16_ok
18+
+ = expect_data_types(f32, bf16, f32, f32, undef)
19+
+ && attr()->has_default_values(
20+
+ primitive_attr_t::skip_mask_t::post_ops, f32);
21+
const bool ok = is_fwd() && !has_zero_dim_memory()
22+
- && utils::one_of(true, is_fp16_ok, is_fp32_ok)
23+
- && weights_md_.format_kind == format_kind::any
24+
+ && utils::one_of(
25+
+ true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
26+
&& set_default_params() == status::success;
27+
28+
if (!ok) return status::unimplemented;
29+
diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp
30+
index fdd7b1776..5a3dc1ea7 100644
31+
--- a/src/cpu/cpu_inner_product_list.cpp
32+
+++ b/src/cpu/cpu_inner_product_list.cpp
33+
@@ -83,6 +83,10 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
34+
CPU_INSTANCE(ref_inner_product_fwd_t)
35+
nullptr,
36+
}},
37+
+ {{forward, f32, bf16, f32}, {
38+
+ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
39+
+ nullptr,
40+
+ }},
41+
{{backward_data, f32, f32, f32}, REG_BWD_PK({
42+
CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
43+
CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
44+
--
45+
2.34.1
46+

0 commit comments

Comments
 (0)