Skip to content

Commit 05bcc9f

Browse files
committed
cpu: avx512_core: fix access to padded areas in dw bwd/wu convs
1 parent 5881410 commit 05bcc9f

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

src/cpu/jit_avx512_core_bf16_dw_conv_kernel.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -557,6 +557,8 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_step_unroll(
557557
if (i_ur == 0) {
558558
for (int c = 0; c < input_overlap; ++c) {
559559
int off_input = (c - pad_offset) * jcp.ch_block;
560+
if (off_input < 0 && unroll_w == jcp.ow)
561+
continue;
560562
Zmm zmm_input = get_input_reg(c);
561563
vpmovzxwd(zmm_input,
562564
ptr[reg_tmp_input + off_input * jcp.typesize_in]);
@@ -565,6 +567,8 @@ inline void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_step_unroll(
565567
for (int c = 0; c < cascade_input; ++c) {
566568
int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
567569
int off_input = (overlap + c - pad_offset) * jcp.ch_block;
570+
if (off_input < 0 || overlap + c + l_pad > right_border)
571+
continue;
568572
Zmm zmm_input = get_input_reg(overlap + c);
569573
vpmovzxwd(zmm_input,
570574
ptr[reg_tmp_input + off_input * jcp.typesize_in]);

src/cpu/jit_uni_dw_conv_kernel_f32.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -557,6 +557,8 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
557557
for (int c = 0; c < input_overlap; ++c) {
558558
int off_input
559559
= ((c - pad_offset) * reg_repeats + r) * simd_w;
560+
if (off_input < 0 && unroll_w == jcp.ow)
561+
continue;
560562
Vmm vmm_input
561563
= get_input_reg((c % jcp.kw) * reg_repeats + r);
562564
uni_vmovups(vmm_input,
@@ -568,6 +570,8 @@ inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
568570
int off_input
569571
= ((overlap + c - pad_offset) * reg_repeats + r)
570572
* simd_w;
573+
if (off_input < 0 || overlap + c + l_pad > right_border)
574+
continue;
571575
Vmm vmm_input = get_input_reg(
572576
((overlap + c) % jcp.kw) * reg_repeats + r);
573577
uni_vmovups(vmm_input,

0 commit comments

Comments
 (0)