@@ -224,19 +224,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d() const {
224224 };
225225
226226 if (jpp.simple_alg ) {
227-
228- int back_pad = (jpp.od - 1 ) * jpp.stride_d + jpp.kd
227+ const int back_pad = (jpp.od - 1 ) * jpp.stride_d + jpp.kd
229228 - jpp.f_pad - jpp.id ;
230- // zero-out untouched portions of diff_src (when back_pad is negative)
231- if (back_pad < 0 )
232- parallel_nd (jpp.mb , jpp.nb_c , -back_pad, jpp.ih , jpp.iw ,
233- [&](int n, int b_c, int id_e, int ih, int iw) {
234- int id_s = jpp.id + back_pad;
235- auto ds = &diff_src[diff_src_d.blk_off (n, b_c,
236- id_s + id_e, ih, iw)];
237- for (int i = 0 ; i < jpp.c_block ; ++i)
238- ds[i] = (data_t )0 .f ;
239- });
240229
241230 parallel_nd (jpp.mb , jpp.nb_c , jpp.od ,
242231 [&](int n, int b_c, int od) {
@@ -251,6 +240,19 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d() const {
251240 ker (n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
252241 (oh == 0 ) ? zero_s : 0 , 0 );
253242 }
243+
244+ // zero-out untouched portion of diff_src when back_pad is negative
245+ if (back_pad < 0 && od == jpp.od - 1 )
246+ for (auto id_e = 0 ; id_e < -back_pad; ++id_e)
247+ for (auto ih = 0 ; ih < jpp.ih ; ++ih)
248+ for (auto iw = 0 ; iw < jpp.iw ; ++iw) {
249+ int id_s = jpp.id + back_pad;
250+ auto ds = &diff_src[diff_src_d.blk_off (n, b_c,
251+ id_s + id_e, ih, iw)];
252+ for (int i = 0 ; i < jpp.c_block ; ++i)
253+ ds[i] = (data_t )0 .f ;
254+ }
255+
254256 });
255257 } else {
256258 ptrdiff_t nelems = (ptrdiff_t )jpp.mb * (ptrdiff_t )jpp.c
0 commit comments