@@ -29,6 +29,8 @@ namespace mkldnn {
2929namespace impl {
3030namespace cpu {
3131
32+ using namespace nstl ;
33+
3234template <data_type_t data_type, data_type_t acc_type>
3335void ref_pooling_fwd_t <data_type, acc_type>::execute_forward() const {
3436 using namespace alg_kind ;
@@ -71,11 +73,13 @@ void ref_pooling_fwd_t<data_type, acc_type>::execute_forward() const {
7173 if (ws) {
7274 assert (ws_dt == data_type::u8 || ws_dt == data_type::s32);
7375 size_t offset = is_3d
74- ? ws_d.off (mb, oc, od, oh, ow) : ws_d.off (mb, oc, oh, ow);;
76+ ? ws_d.off (mb, oc, od, oh, ow) : ws_d.off (mb, oc, oh, ow);
7577 if (ws_dt == data_type::u8 ) {
78+ const int u8_max = numeric_limits<
79+ typename prec_traits<data_type::u8 >::type>::max ();
7680 if (value == -1 )
77- value = 255 ;
78- assert (0 <= value && value <= 255 );
81+ value = u8_max ;
82+ assert (0 <= value && value <= u8_max );
7983 ws[offset] = value;
8084 } else
8185 reinterpret_cast <int *>(ws)[offset] = value;
@@ -257,8 +261,10 @@ void ref_pooling_bwd_t<data_type, acc_type>::execute_backward() const {
257261 const int index = ws_d.data_type () == data_type::u8
258262 ? (int )ws[ws_off] : ((int *)ws)[ws_off];
259263
260- if (index == -1
261- || (ws_d.data_type () == data_type::u8 && index == 255 ))
264+ const int invalid_index_value = ws_d.data_type () == data_type::u8
265+ ? numeric_limits<typename prec_traits<data_type::u8 >::type>::max ()
266+ : -1 ;
267+ if (index == invalid_index_value)
262268 return ; // corner case: pool window is outside of real input domain
263269 // for this point, do nothing
264270
@@ -313,8 +319,10 @@ void ref_pooling_bwd_t<data_type, acc_type>::execute_backward() const {
313319 const int index = ws_d.data_type () == data_type::u8
314320 ? (int )ws[ws_off] : ((int *)ws)[ws_off];
315321
316- if (index == -1
317- || (ws_d.data_type () == data_type::u8 && index == 255 ))
322+ const int invalid_index_value = ws_d.data_type () == data_type::u8
323+ ? numeric_limits<typename prec_traits<data_type::u8 >::type>::max ()
324+ : -1 ;
325+ if (index == invalid_index_value)
318326 return ; // corner case: pool window is outside of real input domain
319327 // for this point, do nothing
320328
0 commit comments