@@ -423,8 +423,10 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
423423
424424template <SIMPLE_REORDER_TEMPL_DECL>
425425struct simple_reorder_impl <SIMPLE_REORDER_TEMPL_CALL,
426- typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_8c
427- && format_traits<fmt_o>::blk_fmt == bf::_16c>::type>
426+ typename utils::enable_if<true
427+ && (format_traits<fmt_i>::blk_fmt == bf::_4c
428+ || format_traits<fmt_i>::blk_fmt == bf::_8c)
429+ && format_traits<fmt_o>::blk_fmt == bf::_16c>::type>
428430{
429431 static bool is_applicable (const memory_desc_wrapper &input_d,
430432 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
@@ -439,47 +441,47 @@ typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_8c
439441
440442 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1 ;
441443 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3 ;
442- constexpr int blksize_16 = format_traits<fmt_o>::blk_size;
443- constexpr int blksize_8 = format_traits<fmt_i>::blk_size;
444+ constexpr int blksize_fmt_o = format_traits<fmt_o>::blk_size;
445+ constexpr int blksize_fmt_i = format_traits<fmt_i>::blk_size;
444446 constexpr int ic_mult = order_keep ? 2 : 1 ;
445447 constexpr int oc_mult = order_keep ? 1 : 2 ;
446448
447- const auto &nchw8c_d = order_keep ? input_d : output_d;
449+ const auto &fmt_i_d = order_keep ? input_d : output_d;
448450 const auto &dims = input_d.dims ();
449451 const auto &pdims = order_keep ? output_d.blocking_desc ().padding_dims
450452 : input_d.blocking_desc ().padding_dims ;
451- const auto stride_8c = nchw8c_d .blocking_desc ().strides [0 ];
453+ const auto stride_fmt_i = fmt_i_d .blocking_desc ().strides [0 ];
452454
453455 const int C = dims[1 ];
454456 const int D = is_3d ? dims[2 ] : 1 ;
455457 const int H = is_1d ? 1 : dims[2 + is_3d];
456458 const int W = dims[3 + is_3d - is_1d];
457459
458460 auto ker = [&](const data_t <type_i> *i, data_t <type_o> *o,
459- const int block_16 ) {
460- const int nb = (block_16 - 1 ) / blksize_8 + 1 ;
461+ const int block_fmt_o ) {
462+ const int nb = (block_fmt_o - 1 ) / blksize_fmt_i + 1 ;
461463 if (alpha == 1.0 && beta == 0.0 ) {
462464 for (int b = 0 ; b < nb; ++b) {
463- const ptrdiff_t i_off = order_keep ? b * stride_8c [1 ]
464- : b * blksize_8 ;
465- const ptrdiff_t o_off = order_keep ? b * blksize_8
466- : b * stride_8c [1 ];
467- const int block_8 = nstl::min (blksize_8 ,
468- block_16 - b * blksize_8 );
469- for (int c = 0 ; c < block_8 ; ++c) {
465+ const ptrdiff_t i_off = order_keep ? b * stride_fmt_i [1 ]
466+ : b * blksize_fmt_i ;
467+ const ptrdiff_t o_off = order_keep ? b * blksize_fmt_i
468+ : b * stride_fmt_i [1 ];
469+ const int block_fmt_i = nstl::min (blksize_fmt_i ,
470+ block_fmt_o - b * blksize_fmt_i );
471+ for (int c = 0 ; c < block_fmt_i ; ++c) {
470472 o[o_off + c] = _qz_a1b0<type_i, type_o>()(
471473 i[i_off + c], rmode);
472474 }
473475 }
474476 } else {
475477 for (int b = 0 ; b < nb; ++b) {
476- const ptrdiff_t i_off = order_keep ? b * stride_8c [1 ]
477- : b * blksize_8 ;
478- const ptrdiff_t o_off = order_keep ? b * blksize_8
479- : b * stride_8c [1 ];
480- const int block_8 = nstl::min (blksize_8 ,
481- block_16 - b * blksize_8 );
482- for (int c = 0 ; c < block_8 ; ++c) {
478+ const ptrdiff_t i_off = order_keep ? b * stride_fmt_i [1 ]
479+ : b * blksize_fmt_i ;
480+ const ptrdiff_t o_off = order_keep ? b * blksize_fmt_i
481+ : b * stride_fmt_i [1 ];
482+ const int block_fmt_i = nstl::min (blksize_fmt_i ,
483+ block_fmt_o - b * blksize_fmt_i );
484+ for (int c = 0 ; c < block_fmt_i ; ++c) {
483485 o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
484486 o[o_off + c], alpha, beta, rmode);
485487 }
@@ -491,12 +493,12 @@ typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_8c
491493 ( is_1d ? (md).blk_off (n, c, w) \
492494 : is_3d ? (md).blk_off (n, c, d, h, w) : (md).blk_off (n, c, h, w))
493495
494- parallel_nd (dims[0 ], pdims[1 ] / blksize_16 , D, H, W,
496+ parallel_nd (dims[0 ], pdims[1 ] / blksize_fmt_o , D, H, W,
495497 [&](int n, int nb_c, int d, int h, int w) {
496498 auto i = &input[data_blk_off (input_d, n, ic_mult * nb_c, d, h, w)];
497499 auto o = &output[data_blk_off (output_d, n, oc_mult * nb_c, d, h, w)];
498- const int block_16 = nstl::min (blksize_16 , C - nb_c * blksize_16 );
499- ker (i, o, block_16 );
500+ const int block_fmt_o = nstl::min (blksize_fmt_o , C - nb_c * blksize_fmt_o );
501+ ker (i, o, block_fmt_o );
500502 });
501503
502504# undef data_blk_off
0 commit comments