@@ -57,16 +57,19 @@ void gemm_convolution_fwd_t::execute_forward() const {
5757 const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id ;
5858 const size_t weights_oc_size = jcp.ic * jcp.ks ;
5959 const size_t weights_g_size = weights_oc_size * jcp.oc ;
60+ const bool is_problem_3d = pd ()->ndims () == 5 ;
6061
6162 assert (IMPLICATION (
62- jcp.id != 1 , jcp.os_block == jcp.os && jcp.ic_block == jcp.ic ));
63-
64- if (jcp.im2col_sz && jcp.id != 1 )
65- parallel_nd (jcp.im2col_sz * jcp.nthr ,
66- [&](ptrdiff_t i) { col[i] = (data_t )0 ; });
63+ is_problem_3d, jcp.os_block == jcp.os && jcp.ic_block == jcp.ic ));
6764
6865 parallel (jcp.nthr , [&](const int ithr, const int nthr) {
6966 data_t *_col = col + (ptrdiff_t )ithr * jcp.im2col_sz ;
67+ if (is_problem_3d) {
68+ // jit_gemm_convolution_utils::im2col_3d() requires external
69+ // data initialization by zeroes
70+ for (ptrdiff_t i = 0 ; i < jcp.im2col_sz ; i++)
71+ _col[i] = (data_t )0 ;
72+ }
7073
7174 auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev,
7275 im_pos_t &step, const im_pos_t &end) {
@@ -82,7 +85,7 @@ void gemm_convolution_fwd_t::execute_forward() const {
8285 prev = curr;
8386
8487 if (jcp.im2col_sz && do_im2col) {
85- if (jcp. id == 1 )
88+ if (!is_problem_3d )
8689 jit_gemm_convolution_utils::im2col<float >(
8790 jcp, _src, _col, curr.sp , step.sp , curr.ic , step.ic );
8891 else
@@ -155,7 +158,7 @@ void gemm_convolution_fwd_t::execute_forward() const {
155158 im_pos_t start, end;
156159 end.ic = jcp.ic ;
157160
158- if (jcp. id == 1 ) {
161+ if (!is_problem_3d ) {
159162 const int sp_work = jcp.mb * jcp.ngroups * jcp.od * jcp.os ;
160163 balance2D (nthr, ithr, sp_work, start.sp , end.sp , jcp.oc , start.oc ,
161164 end.oc , jcp.nthr_oc );
@@ -218,11 +221,7 @@ void gemm_convolution_bwd_data_t::execute_backward_data() const {
218221 const int LDC = jcp.im2col_sz ? m : M;
219222
220223 const size_t work_amount = (size_t )jcp.ngroups * jcp.mb ;
221-
222- if (jcp.id > 1 ) {
223- const ptrdiff_t diff_src_sz = (ptrdiff_t )(work_amount * src_step);
224- parallel_nd (diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t )0 ; });
225- }
224+ const bool is_problem_3d = pd ()->ndims () == 5 ;
226225
227226 parallel (jcp.nthr , [&](const int ithr, const int nthr) {
228227 data_t *_col = col + (ptrdiff_t )ithr * jcp.im2col_sz ;
@@ -234,6 +233,13 @@ void gemm_convolution_bwd_data_t::execute_backward_data() const {
234233 for (size_t iwork = start; iwork < end; ++iwork) {
235234
236235 data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step;
236+ if (is_problem_3d && jcp.im2col_sz > 0 ) {
237+ // jit_gemm_convolution_utils::col2im_3d() assumes that the
238+ // accumulator is initialized by zeroes
239+ for (size_t i = 0 ; i < src_step; i++)
240+ _diff_src[i] = (data_t )0 ;
241+ }
242+
237243 const data_t *_weights = weights + g * weights_g_size;
238244 for (int od = 0 ; od < jcp.od ; ++od) {
239245 const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g)
@@ -245,7 +251,7 @@ void gemm_convolution_bwd_data_t::execute_backward_data() const {
245251 jcp.im2col_sz ? _col:_diff_src + od * m, &LDC);
246252
247253 if (jcp.im2col_sz ) {
248- if (jcp. id == 1 )
254+ if (!is_problem_3d )
249255 jit_gemm_convolution_utils::col2im (jcp, _col,
250256 _diff_src);
251257 else
@@ -278,9 +284,7 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
278284 const int N = jcp.oc ;
279285 const int M = jcp.ic * jcp.ks ;
280286 const int LDA = jcp.im2col_sz ? k : K;
281-
282- parallel_nd (jcp.im2col_sz * jcp.nthr ,
283- [&](ptrdiff_t i) { col[i] = (data_t )0 ; });
287+ const bool is_problem_3d = pd ()->ndims () == 5 ;
284288
285289 parallel (jcp.nthr , [&](const int ithr, const int nthr) {
286290 int ithr_g, nthr_g, ithr_mb, nthr_mb;
@@ -300,6 +304,13 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
300304 assert (IMPLICATION ((g_end - g_start) > 1 , need_reduction == 0 ));
301305
302306 data_t *_col = col + (ptrdiff_t )ithr * jcp.im2col_sz ;
307+ if (is_problem_3d) {
308+ // jit_gemm_convolution_utils::im2col_3d() requires external
309+ // data initialization by zeroes
310+ for (ptrdiff_t i = 0 ; i < jcp.im2col_sz ; i++)
311+ _col[i] = (data_t )0 ;
312+ }
313+
303314 data_t *weights_reduce_base = wei_reduction
304315 + ithr_g * nthr_mb * weights_g_size;
305316 data_t *weights_reduce = weights_reduce_base
@@ -315,7 +326,7 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
315326 + (mb*jcp.ngroups +g)*dst_step + od * k;
316327
317328 if (jcp.im2col_sz ) {
318- if (jcp. id == 1 )
329+ if (!is_problem_3d )
319330 jit_gemm_convolution_utils::im2col<float >(
320331 jcp, _src, _col, 0 , jcp.os , 0 , jcp.ic );
321332 else
0 commit comments