Skip to content

Commit 2a6cb12

Browse files
committed
cpu: conv: bwd_d: fix r_overflow1 bug
1 parent ae14f5e commit 2a6cb12

File tree

4 files changed

+49
-27
lines changed

4 files changed

+49
-27
lines changed

src/cpu/jit_avx2_conv_kernel_f32.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -757,8 +757,8 @@ void jit_avx2_conv_bwd_data_kernel_f32::generate() {
757757
int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
758758
int r_overflow = nstl::max(0, (jcp.kw - 1
759759
- nstl::max(0, jcp.r_pad)) / jcp.stride_w);
760-
int r_overflow1 = nstl::max(0, (jcp.kw - 1
761-
- nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
760+
int r_overflow1 = nstl::max(
761+
0, (jcp.kw - 1 - jcp.r_pad - jcp.ur_w_tail) / jcp.stride_w);
762762

763763
int n_oi = jcp.iw / jcp.ur_w;
764764
if (r_overflow1 > 0)
@@ -946,13 +946,16 @@ status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
946946

947947
jcp.ur_w_tail = jcp.iw % jcp.ur_w;
948948

949-
int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail
950-
- nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
951-
/* maximum 1 ur_w block with r_overflow so far */
952-
if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
953-
return status::unimplemented;
954-
955-
if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
949+
int r_overflow_no_tail = nstl::max(
950+
0, (jcp.kw - 1 - jcp.r_pad - jcp.ur_w_tail) / jcp.stride_w);
951+
bool tails_not_ok = false
952+
/* maximum 1 ur_w block with r_overflow so far */
953+
|| r_overflow_no_tail * jcp.stride_w > jcp.ur_w
954+
/* ur_w must be a multiple of stride */
955+
|| ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
956+
/* r_pad must not extend beyond ur_w_tail */
957+
|| ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
958+
if (tails_not_ok)
956959
return status::unimplemented;
957960

958961
return status::success;

src/cpu/jit_avx512_common_conv_kernel.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,8 +2478,8 @@ void jit_avx512_common_conv_bwd_data_kernel_f32::generate()
24782478
int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
24792479
int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
24802480
- nstl::max(0, jcp.r_pad)) / stride_w);
2481-
int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
2482-
- nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
2481+
int r_overflow1 = nstl::max(
2482+
0, ((kw - 1) * dilate_w - jcp.r_pad - ur_w_tail) / stride_w);
24832483

24842484
int n_oi = iw / ur_w;
24852485
if (r_overflow1 > 0) n_oi--;
@@ -2639,8 +2639,9 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
26392639
}
26402640
int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
26412641
- jcp.l_pad) / jcp.stride_w);
2642-
int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2643-
- nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
2642+
int r_overflow1 = nstl::max(0,
2643+
((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.r_pad - jcp.iw % jcp.ur_w)
2644+
/ jcp.stride_w);
26442645
int n_oi = jcp.iw / jcp.ur_w;
26452646
if (r_overflow1 > 0) n_oi--;
26462647

@@ -2782,11 +2783,17 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
27822783

27832784
if (l_overflow * jcp.stride_w > jcp.ur_w)
27842785
return status::unimplemented;
2785-
int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2786-
- nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
2787-
if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
2788-
return status::unimplemented;
2789-
if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
2786+
int r_overflow_no_tail = nstl::max(0,
2787+
((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.r_pad - jcp.ur_w_tail)
2788+
/ jcp.stride_w);
2789+
bool tails_not_ok = false
2790+
/* maximum 1 ur_w block with r_overflow so far */
2791+
|| r_overflow_no_tail * jcp.stride_w > jcp.ur_w
2792+
/* ur_w must be a multiple of stride */
2793+
|| ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
2794+
/* r_pad must not extend beyond ur_w_tail */
2795+
|| ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
2796+
if (tails_not_ok)
27902797
return status::unimplemented;
27912798

27922799
pick_loop_order(jcp);

src/cpu/jit_avx512_core_bf16_conv_kernel.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -950,8 +950,8 @@ void jit_avx512_core_bf16_bwd_data_kernel::generate()
950950
int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
951951
int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
952952
- nstl::max(0, jcp.r_pad)) / stride_w);
953-
int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
954-
- nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
953+
int r_overflow1 = nstl::max(
954+
0, ((kw - 1) * dilate_w - jcp.r_pad - ur_w_tail) / stride_w);
955955

956956
int n_oi = iw / ur_w;
957957
if (r_overflow1 > 0) n_oi--;
@@ -1095,8 +1095,9 @@ status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(
10951095
reserved */
10961096
int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
10971097
- jcp.l_pad) / jcp.stride_w);
1098-
int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
1099-
- nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
1098+
int r_overflow1 = nstl::max(0,
1099+
((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.r_pad - jcp.iw % jcp.ur_w)
1100+
/ jcp.stride_w);
11001101
int n_oi = jcp.iw / jcp.ur_w;
11011102
if (r_overflow1 > 0) n_oi--;
11021103

@@ -1153,11 +1154,17 @@ status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(
11531154

11541155
if (l_overflow * jcp.stride_w > jcp.ur_w)
11551156
return status::unimplemented;
1156-
int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
1157-
- nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
1158-
if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
1159-
return status::unimplemented;
1160-
if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
1157+
int r_overflow_no_tail = nstl::max(0,
1158+
((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.r_pad - jcp.ur_w_tail)
1159+
/ jcp.stride_w);
1160+
bool tails_not_ok = false
1161+
/* maximum 1 ur_w block with r_overflow so far */
1162+
|| r_overflow_no_tail * jcp.stride_w > jcp.ur_w
1163+
/* ur_w must be a multiple of stride */
1164+
|| ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
1165+
/* r_pad must not extend beyond ur_w_tail */
1166+
|| ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
1167+
if (tails_not_ok)
11611168
return status::unimplemented;
11621169

11631170
pick_loop_order(jcp);

tests/benchdnn/inputs/test_conv_regression_general

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,8 @@ mb1_g2ic6oc16_ih5oh5kh1sh1dh0ph0_iw5ow5kw1sw1dw0pw0
9898

9999
# MFDNN-1968 1x1 conv asymmetrical strides with stride_h=1
100100
--reset --dir=BWD_D mb1ic2ih10iw8oc4oh10ow4kh1kw1sh1sw2ph0pw0dh0dw0nS
101+
102+
# github #542
103+
# r_overflow1 corner case
104+
--reset --dir=bwd_d mb1_ic9oc3_ih32oh10kh5sh3dh0ph1
105+

0 commit comments

Comments
 (0)