@@ -371,7 +371,8 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
371371 int _start = (jcp.signed_input ) ? 0 : jj_start;
372372 int _end = (jcp.signed_input ) ? ur_w : jj_end;
373373
374- int tail_size = jcp.ic_without_padding % 4 ;
374+ int tail_size = jcp.is_depthwise ? jcp.ngroups % jcp.ch_block
375+ : jcp.ic_without_padding % 4 ;
375376 int n_ic_blocks = jcp.is_depthwise
376377 ? 1
377378 : (last_ic_block_flag & ~no_last_block ? div_up (
@@ -393,7 +394,12 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
393394 if (jj >= jj_start && jj < jj_end
394395 && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0 )) {
395396 if (jcp.is_depthwise ) {
396- vpmovzxbd (zmm_inp (jj, jcp.nb_oc_blocking ),
397+ auto zmm_src = zmm_inp (jj, jcp.nb_oc_blocking );
398+ if (tail_size != 0 ) {
399+ assert (jcp.nb_oc_blocking == 1 );
400+ zmm_src = zmm_src | ktail_mask | T_z;
401+ }
402+ vpmovzxbd (zmm_src,
397403 EVEX_compress_addr (
398404 aux_reg_src, aux_src_off));
399405 } else if ((last_ic_block_flag & last_sp_block)
@@ -875,8 +881,15 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() {
875881 : jcp.oc_without_padding % jcp.oc_block ;
876882 int mask = (1 << tail_size) - 1 ;
877883 Reg32 regw_tmp = reg_nur_w.cvt32 ();
884+ Label skip_tail_mask;
885+ if (jcp.is_depthwise ) {
886+ kxnorw (ktail_mask, ktail_mask, ktail_mask);
887+ cmp (dword[param1 + GET_OFF (oc_blocks)], jcp.nb_ch - 1 );
888+ jne (skip_tail_mask, T_NEAR);
889+ }
878890 mov (regw_tmp, mask);
879891 kmovw (ktail_mask, regw_tmp);
892+ L (skip_tail_mask);
880893 }
881894
882895 mov (reg_src, ptr[param1 + GET_OFF (src)]);
0 commit comments