We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e98e291 commit d89bf4bCopy full SHA for d89bf4b
src/common/format_traits.hpp
@@ -119,6 +119,7 @@ DECL_TRAITS(io, wei, _, 2, 0);
119
/* wei: 3D */
120
DECL_TRAITS(oiw, wei, _, 3, 1);
121
DECL_TRAITS(wio, wei, _, 3, 1);
122
+DECL_TRAITS(owi, wei, _, 3, 1);
123
DECL_TRAITS(Owi4o, wei, _4o, 3, 1);
124
DECL_TRAITS(OIw4i4o, wei, _4i4o, 3, 1);
125
DECL_TRAITS(Owi8o, wei, _8o, 3, 1);
@@ -140,6 +141,7 @@ DECL_TRAITS(OIw4i16o4i_s8s8, wei, _4i16o4i_s8s8, 3, 1);
140
141
DECL_TRAITS(oihw, wei, _, 4, 2);
142
DECL_TRAITS(ihwo, wei, _, 4, 2);
143
DECL_TRAITS(hwio, wei, _, 4, 2);
144
+DECL_TRAITS(ohwi, wei, _, 4, 2);
145
DECL_TRAITS(iohw, wei, _, 4, 2);
146
DECL_TRAITS(hwio_s8s8, wei, _, 4, 2);
147
DECL_TRAITS(oIhw8i, wei, _8i, 4, 2);
@@ -164,6 +166,7 @@ DECL_TRAITS(Ohwi16o, wei, _16o, 4, 2);
164
166
165
167
/* wei: 5D */
168
DECL_TRAITS(dhwio, wei, _, 5, 3);
169
+DECL_TRAITS(odhwi, wei, _, 5, 3);
170
DECL_TRAITS(oidhw, wei, _, 5, 3);
171
DECL_TRAITS(OIdhw4i4o, wei, _4i4o, 5, 3);
172
DECL_TRAITS(Odhwi4o, wei, _4o, 5, 3);
src/common/memory_desc_wrapper.cpp
@@ -294,6 +294,13 @@ status_t fill_wio(memory_desc_t &md) {
294
return fill_nonblocked(md, perm);
295
}
296
297
+status_t fill_owi(memory_desc_t &md) {
298
+ if (md.ndims != 3) return invalid_arguments;
299
+
300
+ const int perm[3] = {0, 2, 1};
301
+ return fill_nonblocked(md, perm);
302
+}
303
304
status_t fill_Owi4o(memory_desc_t &md) {
305
if (md.ndims != 3) return invalid_arguments;
306
@@ -455,6 +462,13 @@ status_t fill_hwio(memory_desc_t &md) {
455
462
456
463
457
464
465
+status_t fill_ohwi(memory_desc_t &md) {
466
+ if (md.ndims != 4) return invalid_arguments;
467
468
+ const int perm[4] = {0, 2, 3, 1};
469
470
471
458
472
status_t fill_iohw(memory_desc_t &md) {
459
473
if (md.ndims != 4) return invalid_arguments;
460
474
@@ -469,6 +483,13 @@ status_t fill_dhwio(memory_desc_t &md) {
483
484
485
486
+status_t fill_odhwi(memory_desc_t &md) {
487
+ if (md.ndims != 5) return invalid_arguments;
488
489
+ const int perm[5] = {0, 2, 3, 4, 1};
490
491
492
493
status_t fill_OIhw4i4o(memory_desc_t &md) {
494
495
@@ -1331,6 +1352,7 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc)
1331
1352
case io: return fill_io(memory_desc);
1332
1353
case oiw: return fill_oiw(memory_desc);
1333
1354
case wio: return fill_wio(memory_desc);
1355
+ case owi: return fill_owi(memory_desc);
1334
1356
case Owi4o: return fill_Owi4o(memory_desc);
1335
1357
case OIw4i4o: return fill_OIw4i4o(memory_desc);
1336
1358
case Owi8o: return fill_Owi8o(memory_desc);
@@ -1348,9 +1370,11 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc)
1348
1370
case oihw: return fill_oihw(memory_desc);
1349
1371
case ihwo: return fill_ihwo(memory_desc);
1350
1372
case hwio: return fill_hwio(memory_desc);
1373
+ case ohwi: return fill_ohwi(memory_desc);
1351
1374
case iohw: return fill_iohw(memory_desc);
1375
case hwio_s8s8: return fill_hwio(memory_desc);
1376
case dhwio: return fill_dhwio(memory_desc);
1377
+ case odhwi: return fill_odhwi(memory_desc);
1378
case OIhw4i4o: return fill_OIhw4i4o(memory_desc);
1379
case OIhw8i8o: return fill_OIhw8i8o(memory_desc);
1380
case OIhw16i16o: return fill_OIhw16i16o(memory_desc);
src/common/mkldnn_debug.cpp
@@ -71,12 +71,15 @@ const char *mkldnn_fmt2str(mkldnn_memory_format_t v) {
71
if (v == mkldnn_io) return "io";
72
if (v == mkldnn_oiw) return "oiw";
73
if (v == mkldnn_wio) return "wio";
74
+ if (v == mkldnn_owi) return "owi";
75
if (v == mkldnn_oihw) return "oihw";
76
if (v == mkldnn_hwio) return "hwio";
77
+ if (v == mkldnn_ohwi) return "ohwi";
78
if (v == mkldnn_ihwo) return "ihwo";
79
if (v == mkldnn_iohw) return "iohw";
80
if (v == mkldnn_oidhw) return "oidhw";
81
if (v == mkldnn_dhwio) return "dhwio";
82
+ if (v == mkldnn_odhwi) return "odhwi";
83
if (v == mkldnn_goiw) return "goiw";
84
if (v == mkldnn_goihw) return "goihw";
85
if (v == mkldnn_hwigo) return "hwigo";
src/common/type_helpers.hpp
@@ -113,6 +113,7 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
113
io,
114
oiw,
115
wio,
116
+ owi,
117
Owi4o,
118
OIw4i4o,
Owi8o,
@@ -130,9 +131,11 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
130
131
oihw,
132
ihwo,
133
hwio,
134
+ ohwi,
135
iohw,
136
hwio_s8s8,
137
dhwio,
138
+ odhwi,
139
oidhw,
OIdhw4i4o,
Odhwi4o,
src/cpu/cpu_inner_product_pd.hpp
@@ -57,7 +57,7 @@ inline memory_format_t src_compatible_fmt(int ndims, memory_format_t wei_fmt) {
57
return nc;
58
else if (one_of(wei_fmt, oiw, oihw, oidhw))
59
return utils::pick(ndims - 3, ncw, nchw, ncdhw);
60
- else if (one_of(wei_fmt, wio, hwio, dhwio))
+ else if (one_of(wei_fmt, wio, owi, hwio, ohwi, dhwio, odhwi))
61
return utils::pick(ndims - 3, nwc, nhwc, ndhwc);
62
else if (one_of(wei_fmt, oIhw8i, oIdhw8i))
63
return utils::pick(ndims - 4, nChw8c, nCdhw8c);
0 commit comments