Skip to content

Commit d89bf4b

Browse files
densamoilovtprimak
authored andcommitted
cpu: ip: add formats for gemv cases properly
1 parent e98e291 commit d89bf4b

File tree

5 files changed

+34
-1
lines changed

5 files changed

+34
-1
lines changed

src/common/format_traits.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ DECL_TRAITS(io, wei, _, 2, 0);
119119
/* wei: 3D */
120120
DECL_TRAITS(oiw, wei, _, 3, 1);
121121
DECL_TRAITS(wio, wei, _, 3, 1);
122+
DECL_TRAITS(owi, wei, _, 3, 1);
122123
DECL_TRAITS(Owi4o, wei, _4o, 3, 1);
123124
DECL_TRAITS(OIw4i4o, wei, _4i4o, 3, 1);
124125
DECL_TRAITS(Owi8o, wei, _8o, 3, 1);
@@ -140,6 +141,7 @@ DECL_TRAITS(OIw4i16o4i_s8s8, wei, _4i16o4i_s8s8, 3, 1);
140141
DECL_TRAITS(oihw, wei, _, 4, 2);
141142
DECL_TRAITS(ihwo, wei, _, 4, 2);
142143
DECL_TRAITS(hwio, wei, _, 4, 2);
144+
DECL_TRAITS(ohwi, wei, _, 4, 2);
143145
DECL_TRAITS(iohw, wei, _, 4, 2);
144146
DECL_TRAITS(hwio_s8s8, wei, _, 4, 2);
145147
DECL_TRAITS(oIhw8i, wei, _8i, 4, 2);
@@ -164,6 +166,7 @@ DECL_TRAITS(Ohwi16o, wei, _16o, 4, 2);
164166

165167
/* wei: 5D */
166168
DECL_TRAITS(dhwio, wei, _, 5, 3);
169+
DECL_TRAITS(odhwi, wei, _, 5, 3);
167170
DECL_TRAITS(oidhw, wei, _, 5, 3);
168171
DECL_TRAITS(OIdhw4i4o, wei, _4i4o, 5, 3);
169172
DECL_TRAITS(Odhwi4o, wei, _4o, 5, 3);

src/common/memory_desc_wrapper.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ status_t fill_wio(memory_desc_t &md) {
294294
return fill_nonblocked(md, perm);
295295
}
296296

297+
status_t fill_owi(memory_desc_t &md) {
298+
if (md.ndims != 3) return invalid_arguments;
299+
300+
const int perm[3] = {0, 2, 1};
301+
return fill_nonblocked(md, perm);
302+
}
303+
297304
status_t fill_Owi4o(memory_desc_t &md) {
298305
if (md.ndims != 3) return invalid_arguments;
299306

@@ -455,6 +462,13 @@ status_t fill_hwio(memory_desc_t &md) {
455462
return fill_nonblocked(md, perm);
456463
}
457464

465+
status_t fill_ohwi(memory_desc_t &md) {
466+
if (md.ndims != 4) return invalid_arguments;
467+
468+
const int perm[4] = {0, 2, 3, 1};
469+
return fill_nonblocked(md, perm);
470+
}
471+
458472
status_t fill_iohw(memory_desc_t &md) {
459473
if (md.ndims != 4) return invalid_arguments;
460474

@@ -469,6 +483,13 @@ status_t fill_dhwio(memory_desc_t &md) {
469483
return fill_nonblocked(md, perm);
470484
}
471485

486+
status_t fill_odhwi(memory_desc_t &md) {
487+
if (md.ndims != 5) return invalid_arguments;
488+
489+
const int perm[5] = {0, 2, 3, 4, 1};
490+
return fill_nonblocked(md, perm);
491+
}
492+
472493
status_t fill_OIhw4i4o(memory_desc_t &md) {
473494
if (md.ndims != 4) return invalid_arguments;
474495

@@ -1331,6 +1352,7 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc)
13311352
case io: return fill_io(memory_desc);
13321353
case oiw: return fill_oiw(memory_desc);
13331354
case wio: return fill_wio(memory_desc);
1355+
case owi: return fill_owi(memory_desc);
13341356
case Owi4o: return fill_Owi4o(memory_desc);
13351357
case OIw4i4o: return fill_OIw4i4o(memory_desc);
13361358
case Owi8o: return fill_Owi8o(memory_desc);
@@ -1348,9 +1370,11 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc)
13481370
case oihw: return fill_oihw(memory_desc);
13491371
case ihwo: return fill_ihwo(memory_desc);
13501372
case hwio: return fill_hwio(memory_desc);
1373+
case ohwi: return fill_ohwi(memory_desc);
13511374
case iohw: return fill_iohw(memory_desc);
13521375
case hwio_s8s8: return fill_hwio(memory_desc);
13531376
case dhwio: return fill_dhwio(memory_desc);
1377+
case odhwi: return fill_odhwi(memory_desc);
13541378
case OIhw4i4o: return fill_OIhw4i4o(memory_desc);
13551379
case OIhw8i8o: return fill_OIhw8i8o(memory_desc);
13561380
case OIhw16i16o: return fill_OIhw16i16o(memory_desc);

src/common/mkldnn_debug.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,15 @@ const char *mkldnn_fmt2str(mkldnn_memory_format_t v) {
7171
if (v == mkldnn_io) return "io";
7272
if (v == mkldnn_oiw) return "oiw";
7373
if (v == mkldnn_wio) return "wio";
74+
if (v == mkldnn_owi) return "owi";
7475
if (v == mkldnn_oihw) return "oihw";
7576
if (v == mkldnn_hwio) return "hwio";
77+
if (v == mkldnn_ohwi) return "ohwi";
7678
if (v == mkldnn_ihwo) return "ihwo";
7779
if (v == mkldnn_iohw) return "iohw";
7880
if (v == mkldnn_oidhw) return "oidhw";
7981
if (v == mkldnn_dhwio) return "dhwio";
82+
if (v == mkldnn_odhwi) return "odhwi";
8083
if (v == mkldnn_goiw) return "goiw";
8184
if (v == mkldnn_goihw) return "goihw";
8285
if (v == mkldnn_hwigo) return "hwigo";

src/common/type_helpers.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
113113
io,
114114
oiw,
115115
wio,
116+
owi,
116117
Owi4o,
117118
OIw4i4o,
118119
Owi8o,
@@ -130,9 +131,11 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
130131
oihw,
131132
ihwo,
132133
hwio,
134+
ohwi,
133135
iohw,
134136
hwio_s8s8,
135137
dhwio,
138+
odhwi,
136139
oidhw,
137140
OIdhw4i4o,
138141
Odhwi4o,

src/cpu/cpu_inner_product_pd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ inline memory_format_t src_compatible_fmt(int ndims, memory_format_t wei_fmt) {
5757
return nc;
5858
else if (one_of(wei_fmt, oiw, oihw, oidhw))
5959
return utils::pick(ndims - 3, ncw, nchw, ncdhw);
60-
else if (one_of(wei_fmt, wio, hwio, dhwio))
60+
else if (one_of(wei_fmt, wio, owi, hwio, ohwi, dhwio, odhwi))
6161
return utils::pick(ndims - 3, nwc, nhwc, ndhwc);
6262
else if (one_of(wei_fmt, oIhw8i, oIdhw8i))
6363
return utils::pick(ndims - 4, nChw8c, nCdhw8c);

0 commit comments

Comments
 (0)