Skip to content

Commit 0d55f19

Browse files
committed
fixup: src: cpu: fixed pooling correctness for padding corner case
corrected crosspoint of choosing pooling index workspace data type as value u8_max is reserved for designation of invalid index for u8 data type
1 parent f02096a commit 0d55f19

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

src/cpu/cpu_pooling_pd.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ inline data_type_t pooling_index_data_type(const pooling_desc_t *p) {
3737
/* the simplest way to express 256... */
3838
const int u8_max =
3939
numeric_limits<typename prec_traits<data_type::u8>::type>::max();
40+
/* value u8_max in the case of data_type::u8 is reserved for
41+
designation of invalid index when pooling window is fully placed
42+
outside of source domain */
4043
if( p->src_desc.ndims == 5 || p->diff_src_desc.ndims == 5 ) {
41-
return p->kernel[0] * p->kernel[1] * p->kernel[2] <= u8_max
44+
return p->kernel[0] * p->kernel[1] * p->kernel[2] < u8_max
4245
? data_type::u8 : data_type::s32;
4346
} else {
44-
return p->kernel[0] * p->kernel[1] <= u8_max
47+
return p->kernel[0] * p->kernel[1] < u8_max
4548
? data_type::u8 : data_type::s32;
4649
}
4750
}

src/cpu/nchw_pooling.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ namespace mkldnn {
2929
namespace impl {
3030
namespace cpu {
3131

32+
using namespace nstl;
33+
3234
template <impl::data_type_t data_type>
3335
void nchw_pooling_fwd_t<data_type>::execute_forward() const {
3436
using namespace alg_kind;
@@ -77,9 +79,11 @@ void nchw_pooling_fwd_t<data_type>::execute_forward() const {
7779
+ (size_t)OW * oh
7880
+ (size_t)ow;
7981
if (ws_dt == data_type::u8) {
82+
const int u8_max = numeric_limits<
83+
typename prec_traits<data_type::u8>::type>::max();
8084
if (value == -1)
81-
value = 255;
82-
assert(0 <= value && value <= 255);
85+
value = u8_max;
86+
assert(0 <= value && value <= u8_max);
8387
ws[ws_offset] = value;
8488
} else
8589
reinterpret_cast<int *>(ws)[ws_offset] = value;
@@ -236,8 +240,10 @@ void nchw_pooling_bwd_t<data_type>::execute_backward() const {
236240

237241
const int index = ws_d.data_type() == data_type::u8
238242
? (int)ws[ws_offset] : ((const int *)ws)[ws_offset];
239-
if (index == -1
240-
|| (ws_d.data_type() == data_type::u8 && index == 255))
243+
const int invalid_index_value = ws_d.data_type() == data_type::u8
244+
? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
245+
: -1;
246+
if (index == invalid_index_value)
241247
return; // corner case: pool window is outside of real input domain
242248
// for this point, do nothing
243249

src/cpu/ref_pooling.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ namespace mkldnn {
2929
namespace impl {
3030
namespace cpu {
3131

32+
using namespace nstl;
33+
3234
template <data_type_t data_type, data_type_t acc_type>
3335
void ref_pooling_fwd_t<data_type, acc_type>::execute_forward() const {
3436
using namespace alg_kind;
@@ -71,11 +73,13 @@ void ref_pooling_fwd_t<data_type, acc_type>::execute_forward() const {
7173
if (ws) {
7274
assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
7375
size_t offset = is_3d
74-
? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);;
76+
? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);
7577
if (ws_dt == data_type::u8) {
78+
const int u8_max = numeric_limits<
79+
typename prec_traits<data_type::u8>::type>::max();
7680
if (value == -1)
77-
value = 255;
78-
assert(0 <= value && value <= 255);
81+
value = u8_max;
82+
assert(0 <= value && value <= u8_max);
7983
ws[offset] = value;
8084
} else
8185
reinterpret_cast<int *>(ws)[offset] = value;
@@ -257,8 +261,10 @@ void ref_pooling_bwd_t<data_type, acc_type>::execute_backward() const {
257261
const int index = ws_d.data_type() == data_type::u8
258262
? (int)ws[ws_off] : ((int *)ws)[ws_off];
259263

260-
if (index == -1
261-
|| (ws_d.data_type() == data_type::u8 && index == 255))
264+
const int invalid_index_value = ws_d.data_type() == data_type::u8
265+
? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
266+
: -1;
267+
if (index == invalid_index_value)
262268
return; // corner case: pool window is outside of real input domain
263269
// for this point, do nothing
264270

@@ -313,8 +319,10 @@ void ref_pooling_bwd_t<data_type, acc_type>::execute_backward() const {
313319
const int index = ws_d.data_type() == data_type::u8
314320
? (int)ws[ws_off] : ((int *)ws)[ws_off];
315321

316-
if (index == -1
317-
|| (ws_d.data_type() == data_type::u8 && index == 255))
322+
const int invalid_index_value = ws_d.data_type() == data_type::u8
323+
? numeric_limits<typename prec_traits<data_type::u8>::type>::max()
324+
: -1;
325+
if (index == invalid_index_value)
318326
return; // corner case: pool window is outside of real input domain
319327
// for this point, do nothing
320328

0 commit comments

Comments
 (0)