Skip to content

Commit 3d343bc

Browse files
committed
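src: cpu: gemm conv: correct 3D problem dispatching (select the 3D path by tensor rank, ndims == 5, rather than by input depth, jcp.id)
Remove the separate parallel sections that zero-initialized data; each thread now zeroes its own buffer inside the main parallel region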
1 parent 03cfec9 commit 3d343bc

File tree

4 files changed

+61
-31
lines changed

4 files changed

+61
-31
lines changed

src/cpu/gemm_bf16_convolution.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,24 +306,27 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward() const {
306306
const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
307307
const size_t dst_step = (size_t)jcp.oc * M;
308308
const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
309+
const bool is_problem_3d = pd()->ndims() == 5;
309310

310311
assert(IMPLICATION(
311-
jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
312+
is_problem_3d, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
312313
assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
313314

314315
const int K = jcp.ic * jcp.ks;
315316
const int N = jcp.oc;
316317

317-
if (jcp.im2col_sz && jcp.id != 1)
318-
parallel_nd(jcp.im2col_sz * jcp.nthr,
319-
[&](ptrdiff_t i) { col[i] = (src_data_t)0; });
320-
321318
const int nb_oh = div_up(jcp.oh, jcp.oh_block);
322319
const int nb_ow = div_up(jcp.ow, jcp.ow_block);
323320
const size_t work_amount = (size_t)jcp.ngroups
324321
* jcp.mb * jcp.od * nb_oh * nb_ow;
325322
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
326323
src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
324+
if (is_problem_3d) {
325+
// jit_gemm_convolution_utils::im2col_3d() requires external
326+
// data initialization by zeroes
327+
for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
328+
_col[i] = (src_data_t)0;
329+
}
327330

328331
int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
329332
size_t start = 0, end = 0;
@@ -340,7 +343,7 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward() const {
340343
const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
341344
const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
342345
if (jcp.im2col_sz) {
343-
if (jcp.id == 1)
346+
if (!is_problem_3d)
344347
jit_gemm_convolution_utils::im2col<src_data_t>(
345348
jcp, _src, _col, 0, jcp.os, 0, jcp.ic);
346349
else
@@ -403,6 +406,7 @@ void gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
403406
const int LDC = jcp.im2col_sz ? m : M;
404407

405408
const size_t work_amount = (size_t)jcp.ngroups * jcp.mb;
409+
const bool is_problem_3d = pd()->ndims() == 5;
406410

407411
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
408412
acc_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
@@ -419,7 +423,7 @@ void gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
419423
? acc_base + ithr * rnd_up(src_step, 16)
420424
: (acc_data_t *)diff_src_local;
421425

422-
if (jcp.id > 1 && jcp.im2col_sz > 0) {
426+
if (is_problem_3d && jcp.im2col_sz > 0) {
423427
// jit_gemm_convolution_utils::col2im_3d() assumes that the
424428
// accumulator is initialized by zeroes
425429
for (size_t i = 0; i < src_step; i++)
@@ -437,7 +441,7 @@ void gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
437441
jcp.im2col_sz ? _col: acc + od * m, &LDC);
438442

439443
if (jcp.im2col_sz) {
440-
if (jcp.id == 1)
444+
if (!is_problem_3d)
441445
jit_gemm_convolution_utils::col2im(jcp, _col,
442446
acc);
443447
else
@@ -548,9 +552,7 @@ void gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
548552
const int N = jcp.oc;
549553
const int M = jcp.ic * jcp.ks;
550554
const int LDA = jcp.im2col_sz ? k : K;
551-
552-
parallel_nd(jcp.im2col_sz * jcp.nthr,
553-
[&](ptrdiff_t i) { col[i] = (src_data_t)0; });
555+
const bool is_problem_3d = pd()->ndims() == 5;
554556

555557
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
556558
int ithr_g, nthr_g, ithr_mb, nthr_mb;
@@ -570,6 +572,13 @@ void gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
570572
assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
571573

572574
src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
575+
if (is_problem_3d) {
576+
// jit_gemm_convolution_utils::im2col_3d() requires external
577+
// data initialization by zeroes
578+
for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
579+
_col[i] = (src_data_t)0;
580+
}
581+
573582
acc_data_t *weights_reduce_base = wei_reduction
574583
+ ithr_g * nthr_mb * weights_g_size;
575584
acc_data_t *weights_reduce = weights_reduce_base
@@ -586,7 +595,7 @@ void gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
586595
+ (mb * jcp.ngroups + g) * dst_step + od * k;
587596

588597
if (jcp.im2col_sz) {
589-
if (jcp.id == 1)
598+
if (!is_problem_3d)
590599
jit_gemm_convolution_utils::im2col<src_data_t>(
591600
jcp, _src, _col, 0, jcp.os, 0, jcp.ic);
592601
else

src/cpu/gemm_convolution.cpp

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,19 @@ void gemm_convolution_fwd_t::execute_forward() const {
5757
const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
5858
const size_t weights_oc_size = jcp.ic * jcp.ks;
5959
const size_t weights_g_size = weights_oc_size * jcp.oc;
60+
const bool is_problem_3d = pd()->ndims() == 5;
6061

6162
assert(IMPLICATION(
62-
jcp.id != 1, jcp.os_block == jcp.os && jcp.ic_block == jcp.ic));
63-
64-
if (jcp.im2col_sz && jcp.id != 1)
65-
parallel_nd(jcp.im2col_sz * jcp.nthr,
66-
[&](ptrdiff_t i) { col[i] = (data_t)0; });
63+
is_problem_3d, jcp.os_block == jcp.os && jcp.ic_block == jcp.ic));
6764

6865
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
6966
data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
67+
if (is_problem_3d) {
68+
// jit_gemm_convolution_utils::im2col_3d() requires external
69+
// data initialization by zeroes
70+
for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
71+
_col[i] = (data_t)0;
72+
}
7073

7174
auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev,
7275
im_pos_t &step, const im_pos_t &end) {
@@ -82,7 +85,7 @@ void gemm_convolution_fwd_t::execute_forward() const {
8285
prev = curr;
8386

8487
if (jcp.im2col_sz && do_im2col) {
85-
if (jcp.id == 1)
88+
if (!is_problem_3d)
8689
jit_gemm_convolution_utils::im2col<float>(
8790
jcp, _src, _col, curr.sp, step.sp, curr.ic, step.ic);
8891
else
@@ -155,7 +158,7 @@ void gemm_convolution_fwd_t::execute_forward() const {
155158
im_pos_t start, end;
156159
end.ic = jcp.ic;
157160

158-
if (jcp.id == 1) {
161+
if (!is_problem_3d) {
159162
const int sp_work = jcp.mb * jcp.ngroups * jcp.od * jcp.os;
160163
balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
161164
end.oc, jcp.nthr_oc);
@@ -218,11 +221,7 @@ void gemm_convolution_bwd_data_t::execute_backward_data() const {
218221
const int LDC = jcp.im2col_sz ? m : M;
219222

220223
const size_t work_amount = (size_t)jcp.ngroups * jcp.mb;
221-
222-
if (jcp.id > 1) {
223-
const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step);
224-
parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; });
225-
}
224+
const bool is_problem_3d = pd()->ndims() == 5;
226225

227226
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
228227
data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
@@ -234,6 +233,13 @@ void gemm_convolution_bwd_data_t::execute_backward_data() const {
234233
for (size_t iwork = start; iwork < end; ++iwork) {
235234

236235
data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step;
236+
if (is_problem_3d && jcp.im2col_sz > 0) {
237+
// jit_gemm_convolution_utils::col2im_3d() assumes that the
238+
// accumulator is initialized by zeroes
239+
for (size_t i = 0; i < src_step; i++)
240+
_diff_src[i] = (data_t)0;
241+
}
242+
237243
const data_t *_weights = weights + g * weights_g_size;
238244
for (int od = 0; od < jcp.od; ++od) {
239245
const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g)
@@ -245,7 +251,7 @@ void gemm_convolution_bwd_data_t::execute_backward_data() const {
245251
jcp.im2col_sz ? _col:_diff_src + od * m, &LDC);
246252

247253
if (jcp.im2col_sz) {
248-
if (jcp.id == 1)
254+
if (!is_problem_3d)
249255
jit_gemm_convolution_utils::col2im(jcp, _col,
250256
_diff_src);
251257
else
@@ -278,9 +284,7 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
278284
const int N = jcp.oc;
279285
const int M = jcp.ic * jcp.ks;
280286
const int LDA = jcp.im2col_sz ? k : K;
281-
282-
parallel_nd(jcp.im2col_sz * jcp.nthr,
283-
[&](ptrdiff_t i) { col[i] = (data_t)0; });
287+
const bool is_problem_3d = pd()->ndims() == 5;
284288

285289
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
286290
int ithr_g, nthr_g, ithr_mb, nthr_mb;
@@ -300,6 +304,13 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
300304
assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
301305

302306
data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
307+
if (is_problem_3d) {
308+
// jit_gemm_convolution_utils::im2col_3d() requires external
309+
// data initialization by zeroes
310+
for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
311+
_col[i] = (data_t)0;
312+
}
313+
303314
data_t *weights_reduce_base = wei_reduction
304315
+ ithr_g * nthr_mb * weights_g_size;
305316
data_t *weights_reduce = weights_reduce_base
@@ -315,7 +326,7 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
315326
+ (mb*jcp.ngroups+g)*dst_step + od * k;
316327

317328
if (jcp.im2col_sz) {
318-
if (jcp.id == 1)
329+
if (!is_problem_3d)
319330
jit_gemm_convolution_utils::im2col<float>(
320331
jcp, _src, _col, 0, jcp.os, 0, jcp.ic);
321332
else

src/cpu/gemm_convolution_utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp,
652652
if (is_fwd) {
653653
const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
654654
bool is_blocking_applicable = true && is_fwd && jcp.im2col_sz
655-
&& jcp.id == 1 && jcp.od == 1 && jcp.dilate_h == 0
655+
&& !is_3d && jcp.dilate_h == 0
656656
&& jcp.dilate_w == 0 && !is_depthwise && wei_size < L2 / 2;
657657
if (is_blocking_applicable) {
658658
// looking for oh and ow blocking
@@ -808,7 +808,7 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp,
808808
// gemm implementation which we cannot control
809809
bool is_blocking_applicable = true
810810
&& !is_bf16_conv // TODO: apply blocking to bf16
811-
&& jcp.id == 1 && jcp.od == 1
811+
&& !is_3d
812812
&& (!jcp.im2col_sz
813813
// spatial is small
814814
|| spatial >= max_threads * simd_w

tests/benchdnn/inputs/test_conv_regression_general

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,13 @@ mb1_g16ic16oc16_ih5oh2kh2sh3ph0
132132

133133
# MFDNN-2027 2d corner case (large kernel and top padding, negative bottom padding + stride)
134134
--dir=BWD_W --cfg=f32 g1ic16ih7iw1oc16oh1ow1kh11kw1sh2sw1ph5pw0n
135+
136+
# 3d problem dispatching
137+
--cfg=f32
138+
--dir=FWD_B g1ic16id1ih7iw8oc16od1oh7ow8kd3kh3kw3sd2sh1sw1pd1ph1pw1
139+
--dir=BWD_D g1ic16id1ih7iw8oc16od1oh7ow8kd3kh3kw3sd2sh1sw1pd1ph1pw1
140+
--dir=BWD_W g1ic16id1ih7iw8oc16od1oh7ow8kd3kh3kw3sd2sh1sw1pd1ph1pw1
141+
--cfg=bf16bf16bf16
142+
--dir=FWD_B g1ic16id1ih7iw8oc16od1oh7ow8kd3kh3kw3sd2sh1sw1pd1ph1pw1
143+
--dir=BWD_D g1ic16id1ih7iw8oc16od1oh7ow8kd3kh3kw3sd2sh1sw1pd1ph1pw1
144+
--dir=BWD_W g1ic16id1ih7iw8oc16od1oh7ow8kd3kh3kw3sd2sh1sw1pd1ph1pw1

0 commit comments

Comments
 (0)