Skip to content

Commit 07ddf70

Browse files
committed
cpu: reorder: extend simple reorder to support nChw4c -> nChw16c
1 parent 72b7cbe commit 07ddf70

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

src/cpu/cpu_reorder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ static const rpd_create_f cpu_reorder_impl_list[] = {
158158

159159
/* fp32: blocked <-> blocked with tail */
160160
REG_SR_BIDIR(f32, nCw8c, f32, nCw16c),
161+
REG_SR_BIDIR(f32, nChw4c, f32, nChw16c),
161162
REG_SR_BIDIR(f32, nChw8c, f32, nChw16c),
162163
REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c),
163164

src/cpu/simple_reorder.hpp

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,10 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
423423

424424
template <SIMPLE_REORDER_TEMPL_DECL>
425425
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
426-
typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_8c
427-
&& format_traits<fmt_o>::blk_fmt == bf::_16c>::type>
426+
typename utils::enable_if<true
427+
&& (format_traits<fmt_i>::blk_fmt == bf::_4c
428+
|| format_traits<fmt_i>::blk_fmt == bf::_8c)
429+
&& format_traits<fmt_o>::blk_fmt == bf::_16c>::type>
428430
{
429431
static bool is_applicable(const memory_desc_wrapper &input_d,
430432
const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
@@ -439,47 +441,47 @@ typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_8c
439441

440442
constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
441443
constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
442-
constexpr int blksize_16 = format_traits<fmt_o>::blk_size;
443-
constexpr int blksize_8 = format_traits<fmt_i>::blk_size;
444+
constexpr int blksize_fmt_o = format_traits<fmt_o>::blk_size;
445+
constexpr int blksize_fmt_i = format_traits<fmt_i>::blk_size;
444446
constexpr int ic_mult = order_keep ? 2 : 1;
445447
constexpr int oc_mult = order_keep ? 1 : 2;
446448

447-
const auto &nchw8c_d = order_keep ? input_d : output_d;
449+
const auto &fmt_i_d = order_keep ? input_d : output_d;
448450
const auto &dims = input_d.dims();
449451
const auto &pdims = order_keep ? output_d.blocking_desc().padding_dims
450452
: input_d.blocking_desc().padding_dims;
451-
const auto stride_8c = nchw8c_d.blocking_desc().strides[0];
453+
const auto stride_fmt_i = fmt_i_d.blocking_desc().strides[0];
452454

453455
const int C = dims[1];
454456
const int D = is_3d ? dims[2] : 1;
455457
const int H = is_1d ? 1 : dims[2 + is_3d];
456458
const int W = dims[3 + is_3d - is_1d];
457459

458460
auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
459-
const int block_16) {
460-
const int nb = (block_16 - 1) / blksize_8 + 1;
461+
const int block_fmt_o) {
462+
const int nb = (block_fmt_o - 1) / blksize_fmt_i + 1;
461463
if (alpha == 1.0 && beta == 0.0) {
462464
for (int b = 0; b < nb; ++b) {
463-
const ptrdiff_t i_off = order_keep ? b * stride_8c[1]
464-
: b * blksize_8;
465-
const ptrdiff_t o_off = order_keep ? b * blksize_8
466-
: b * stride_8c[1];
467-
const int block_8 = nstl::min(blksize_8,
468-
block_16 - b * blksize_8);
469-
for (int c = 0; c < block_8; ++c) {
465+
const ptrdiff_t i_off = order_keep ? b * stride_fmt_i[1]
466+
: b * blksize_fmt_i;
467+
const ptrdiff_t o_off = order_keep ? b * blksize_fmt_i
468+
: b * stride_fmt_i[1];
469+
const int block_fmt_i = nstl::min(blksize_fmt_i,
470+
block_fmt_o - b * blksize_fmt_i);
471+
for (int c = 0; c < block_fmt_i; ++c) {
470472
o[o_off + c] = _qz_a1b0<type_i, type_o>()(
471473
i[i_off + c], rmode);
472474
}
473475
}
474476
} else {
475477
for (int b = 0; b < nb; ++b) {
476-
const ptrdiff_t i_off = order_keep ? b * stride_8c[1]
477-
: b * blksize_8;
478-
const ptrdiff_t o_off = order_keep ? b * blksize_8
479-
: b * stride_8c[1];
480-
const int block_8 = nstl::min(blksize_8,
481-
block_16 - b * blksize_8);
482-
for (int c = 0; c < block_8; ++c) {
478+
const ptrdiff_t i_off = order_keep ? b * stride_fmt_i[1]
479+
: b * blksize_fmt_i;
480+
const ptrdiff_t o_off = order_keep ? b * blksize_fmt_i
481+
: b * stride_fmt_i[1];
482+
const int block_fmt_i = nstl::min(blksize_fmt_i,
483+
block_fmt_o - b * blksize_fmt_i);
484+
for (int c = 0; c < block_fmt_i; ++c) {
483485
o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
484486
o[o_off + c], alpha, beta, rmode);
485487
}
@@ -491,12 +493,12 @@ typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_8c
491493
( is_1d ? (md).blk_off(n, c, w) \
492494
: is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))
493495

494-
parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W,
496+
parallel_nd(dims[0], pdims[1] / blksize_fmt_o, D, H, W,
495497
[&](int n, int nb_c, int d, int h, int w) {
496498
auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)];
497499
auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)];
498-
const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16);
499-
ker(i, o, block_16);
500+
const int block_fmt_o = nstl::min(blksize_fmt_o, C - nb_c * blksize_fmt_o);
501+
ker(i, o, block_fmt_o);
500502
});
501503

502504
# undef data_blk_off

0 commit comments

Comments
 (0)