Skip to content

Commit 24eda67

Browse files
committed
fixup: src: cpu: support channel tails in depthwise deconv
1 parent f310ded commit 24eda67

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,8 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
371371
int _start = (jcp.signed_input) ? 0 : jj_start;
372372
int _end = (jcp.signed_input) ? ur_w : jj_end;
373373

374-
int tail_size = jcp.ic_without_padding % 4;
374+
int tail_size = jcp.is_depthwise ? jcp.ngroups % jcp.ch_block
375+
: jcp.ic_without_padding % 4;
375376
int n_ic_blocks = jcp.is_depthwise
376377
? 1
377378
: (last_ic_block_flag & ~no_last_block ? div_up(
@@ -393,7 +394,12 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
393394
if (jj >= jj_start && jj < jj_end
394395
&& ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) {
395396
if (jcp.is_depthwise) {
396-
vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
397+
auto zmm_src = zmm_inp(jj, jcp.nb_oc_blocking);
398+
if (tail_size != 0) {
399+
assert(jcp.nb_oc_blocking == 1);
400+
zmm_src = zmm_src | ktail_mask | T_z;
401+
}
402+
vpmovzxbd(zmm_src,
397403
EVEX_compress_addr(
398404
aux_reg_src, aux_src_off));
399405
} else if ((last_ic_block_flag & last_sp_block)
@@ -875,8 +881,15 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() {
875881
: jcp.oc_without_padding % jcp.oc_block;
876882
int mask = (1 << tail_size) - 1;
877883
Reg32 regw_tmp = reg_nur_w.cvt32();
884+
Label skip_tail_mask;
885+
if (jcp.is_depthwise) {
886+
kxnorw(ktail_mask, ktail_mask, ktail_mask);
887+
cmp(dword[param1 + GET_OFF(oc_blocks)], jcp.nb_ch - 1);
888+
jne(skip_tail_mask, T_NEAR);
889+
}
878890
mov(regw_tmp, mask);
879891
kmovw(ktail_mask, regw_tmp);
892+
L(skip_tail_mask);
880893
}
881894

882895
mov(reg_src, ptr[param1 + GET_OFF(src)]);

0 commit comments

Comments
 (0)