Skip to content

Commit 2a5f63b

Browse files
committed
cpu: pooling: fix failde examples
1 parent 27fc2b3 commit 2a5f63b

File tree

3 files changed

+22
-0
lines changed

3 files changed

+22
-0
lines changed

src/cpu/cpu_pooling_pd.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t {
5454
cpu_memory_pd_t ws_pd_;
5555

5656
virtual status_t init() = 0;
57+
58+
virtual status_t set_default_params() {
59+
using namespace memory_format;
60+
if (src_pd_.desc()->format == any)
61+
CHECK(src_pd_.set_format(nchw));
62+
if (dst_pd_.desc()->format == any)
63+
CHECK(dst_pd_.set_format(src_pd_.desc()->format));
64+
return status::success;
65+
}
5766
};
5867

5968
struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t {
@@ -80,6 +89,15 @@ struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t {
8089
cpu_memory_pd_t ws_pd_;
8190

8291
virtual status_t init() = 0;
92+
93+
virtual status_t set_default_params() {
94+
using namespace memory_format;
95+
if (diff_dst_pd_.desc()->format == any)
96+
CHECK(diff_dst_pd_.set_format(nchw));
97+
if (diff_src_pd_.desc()->format == any)
98+
CHECK(diff_src_pd_.set_format(diff_dst_pd_.desc()->format));
99+
return status::success;
100+
}
83101
};
84102

85103
}

src/cpu/jit_avx2_pooling.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct jit_avx2_pooling_fwd_t: public cpu_primitive_t {
4444
using namespace utils;
4545
assert(engine()->kind() == engine_kind::cpu);
4646
bool ok = true
47+
&& set_default_params() == status::success
4748
&& one_of(desc()->prop_kind, forward_training,
4849
forward_inference)
4950
&& one_of(desc()->alg_kind, pooling_max, pooling_avg)
@@ -98,6 +99,7 @@ struct jit_avx2_pooling_bwd_t: public cpu_primitive_t {
9899
using namespace utils;
99100
assert(engine()->kind() == engine_kind::cpu);
100101
bool ok = true
102+
&& set_default_params() == status::success
101103
&& one_of(desc()->prop_kind, backward, backward_data)
102104
&& one_of(desc()->alg_kind, pooling_max, pooling_avg)
103105
&& everyone_is(data_type::f32, diff_src_pd()->desc()->data_type,

src/cpu/ref_pooling.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ struct ref_pooling_fwd_t: public cpu_primitive_t {
4343
using namespace alg_kind;
4444
assert(engine()->kind() == engine_kind::cpu);
4545
bool ok = true
46+
&& set_default_params() == status::success
4647
&& utils::one_of(desc()->prop_kind, forward_training,
4748
forward_inference)
4849
&& utils::one_of(desc()->alg_kind, pooling_max, pooling_avg)
@@ -90,6 +91,7 @@ struct ref_pooling_bwd_t: public cpu_primitive_t {
9091
using namespace alg_kind;
9192
assert(engine()->kind() == engine_kind::cpu);
9293
bool ok = true
94+
&& set_default_params() == status::success
9395
&& utils::one_of(desc()->prop_kind, backward_data)
9496
&& utils::one_of(desc()->alg_kind, pooling_max, pooling_avg)
9597
&& utils::everyone_is(data_type, diff_dst_pd()->desc()->data_type,

0 commit comments

Comments
 (0)