File tree Expand file tree Collapse file tree 3 files changed +22
-0
lines changed Expand file tree Collapse file tree 3 files changed +22
-0
lines changed Original file line number Diff line number Diff line change @@ -54,6 +54,15 @@ struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t {
5454 cpu_memory_pd_t ws_pd_;
5555
5656 virtual status_t init () = 0;
57+
58+ virtual status_t set_default_params () {
59+ using namespace memory_format ;
60+ if (src_pd_.desc ()->format == any)
61+ CHECK (src_pd_.set_format (nchw));
62+ if (dst_pd_.desc ()->format == any)
63+ CHECK (dst_pd_.set_format (src_pd_.desc ()->format ));
64+ return status::success;
65+ }
5766};
5867
5968struct cpu_pooling_bwd_pd_t : public pooling_bwd_pd_t {
@@ -80,6 +89,15 @@ struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t {
8089 cpu_memory_pd_t ws_pd_;
8190
8291 virtual status_t init () = 0;
92+
93+ virtual status_t set_default_params () {
94+ using namespace memory_format ;
95+ if (diff_dst_pd_.desc ()->format == any)
96+ CHECK (diff_dst_pd_.set_format (nchw));
97+ if (diff_src_pd_.desc ()->format == any)
98+ CHECK (diff_src_pd_.set_format (diff_dst_pd_.desc ()->format ));
99+ return status::success;
100+ }
83101};
84102
85103}
Original file line number Diff line number Diff line change @@ -44,6 +44,7 @@ struct jit_avx2_pooling_fwd_t: public cpu_primitive_t {
4444 using namespace utils ;
4545 assert (engine ()->kind () == engine_kind::cpu);
4646 bool ok = true
47+ && set_default_params () == status::success
4748 && one_of (desc ()->prop_kind , forward_training,
4849 forward_inference)
4950 && one_of (desc ()->alg_kind , pooling_max, pooling_avg)
@@ -98,6 +99,7 @@ struct jit_avx2_pooling_bwd_t: public cpu_primitive_t {
9899 using namespace utils ;
99100 assert (engine ()->kind () == engine_kind::cpu);
100101 bool ok = true
102+ && set_default_params () == status::success
101103 && one_of (desc ()->prop_kind , backward, backward_data)
102104 && one_of (desc ()->alg_kind , pooling_max, pooling_avg)
103105 && everyone_is (data_type::f32 , diff_src_pd ()->desc ()->data_type ,
Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ struct ref_pooling_fwd_t: public cpu_primitive_t {
4343 using namespace alg_kind ;
4444 assert (engine ()->kind () == engine_kind::cpu);
4545 bool ok = true
46+ && set_default_params () == status::success
4647 && utils::one_of (desc ()->prop_kind , forward_training,
4748 forward_inference)
4849 && utils::one_of (desc ()->alg_kind , pooling_max, pooling_avg)
@@ -90,6 +91,7 @@ struct ref_pooling_bwd_t: public cpu_primitive_t {
9091 using namespace alg_kind ;
9192 assert (engine ()->kind () == engine_kind::cpu);
9293 bool ok = true
94+ && set_default_params () == status::success
9395 && utils::one_of (desc ()->prop_kind , backward_data)
9496 && utils::one_of (desc ()->alg_kind , pooling_max, pooling_avg)
9597 && utils::everyone_is (data_type, diff_dst_pd ()->desc ()->data_type ,
You can’t perform that action at this time.
0 commit comments