Skip to content

Commit a03f47d

Browse files
committed
c++ api: remove algorithms from primitives to get the same algorithm for forward and backward
1 parent 64eda90 commit a03f47d

File tree

5 files changed

+281
-302
lines changed

5 files changed

+281
-302
lines changed

include/mkldnn.hpp

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@ template <> struct handle_traits<c_api::mkldnn_primitive_desc_t> {
252252
/// Memory primitive that describes the data.
253253
struct memory: public primitive {
254254
private:
255-
// TODO: check me please
256255
std::shared_ptr<char> _handle;
257256

258257
public:
@@ -487,6 +486,17 @@ inline c_api::mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
487486
return static_cast<c_api::mkldnn_prop_kind_t>(kind);
488487
}
489488

489+
enum algorithm {
490+
convolution_direct = c_api::mkldnn_convolution_direct,
491+
lrn_across_channels = c_api::mkldnn_lrn_across_channels,
492+
lrn_within_channel = c_api::mkldnn_lrn_within_channel,
493+
pooling_max = c_api::mkldnn_pooling_max,
494+
pooling_avg = c_api::mkldnn_pooling_avg
495+
};
496+
static c_api::mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
497+
return static_cast<c_api::mkldnn_alg_kind_t>(aalgorithm);
498+
}
499+
490500
struct reorder : public primitive {
491501
struct primitive_desc : public handle<c_api::mkldnn_primitive_desc_t>{
492502
primitive_desc(const memory::primitive_desc &input,
@@ -697,10 +707,6 @@ struct stream: public handle<c_api::mkldnn_stream_t> {
697707
};
698708

699709
struct convolution_forward: public primitive {
700-
enum algorithm { direct = c_api::mkldnn_convolution_direct };
701-
static c_api::mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
702-
return static_cast<c_api::mkldnn_alg_kind_t>(aalgorithm);
703-
}
704710
struct desc {
705711
c_api::mkldnn_convolution_desc_t data;
706712
desc(prop_kind aprop_kind, algorithm aalgorithm,
@@ -741,7 +747,6 @@ struct convolution_forward: public primitive {
741747
"could not create a convolution forward descriptor");
742748
}
743749
};
744-
// TODO: replace nullptr -> hint
745750
struct primitive_desc : public handle<c_api::mkldnn_primitive_desc_t>{
746751
primitive_desc(const desc &adesc, const engine &aengine) {
747752
c_api::mkldnn_primitive_desc_t result;
@@ -827,10 +832,6 @@ struct convolution_forward: public primitive {
827832
};
828833

829834
struct convolution_backward_data : public primitive {
830-
enum algorithm { direct = c_api::mkldnn_convolution_direct };
831-
static c_api::mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
832-
return static_cast<c_api::mkldnn_alg_kind_t>(aalgorithm);
833-
}
834835
struct desc {
835836
c_api::mkldnn_convolution_desc_t data;
836837
desc(algorithm aalgorithm,
@@ -877,10 +878,6 @@ struct convolution_backward_data : public primitive {
877878
};
878879

879880
struct convolution_backward_weights : public primitive {
880-
enum algorithm { direct = c_api::mkldnn_convolution_direct };
881-
static c_api::mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
882-
return static_cast<c_api::mkldnn_alg_kind_t>(aalgorithm);
883-
}
884881
struct desc {
885882
c_api::mkldnn_convolution_desc_t data;
886883
desc(algorithm aalgorithm,
@@ -927,10 +924,6 @@ struct convolution_backward_weights : public primitive {
927924
};
928925

929926
struct convolution_backward_bias : public primitive {
930-
enum algorithm { direct = c_api::mkldnn_convolution_direct };
931-
static c_api::mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
932-
return static_cast<c_api::mkldnn_alg_kind_t>(aalgorithm);
933-
}
934927
struct desc {
935928
c_api::mkldnn_convolution_desc_t data;
936929
desc(algorithm aalgorithm,
@@ -966,13 +959,6 @@ struct convolution_backward_bias : public primitive {
966959
};
967960

968961
struct lrn_forward : public primitive {
969-
enum algorithm {
970-
across_channels = c_api::mkldnn_lrn_across_channels,
971-
within_channel = c_api::mkldnn_lrn_within_channel,
972-
};
973-
static c_api::mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
974-
return static_cast<c_api::mkldnn_alg_kind_t>(aalgorithm);
975-
}
976962
struct desc {
977963
c_api::mkldnn_lrn_desc_t data;
978964
desc(prop_kind aprop_kind, algorithm aalgorithm,
@@ -1056,13 +1042,6 @@ struct lrn_forward : public primitive {
10561042
};
10571043

10581044
struct pooling_forward : public primitive {
1059-
enum algorithm {
1060-
max = c_api::mkldnn_pooling_max,
1061-
avg = c_api::mkldnn_pooling_avg
1062-
};
1063-
static c_api::mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
1064-
return static_cast<c_api::mkldnn_alg_kind_t>(aalgorithm);
1065-
}
10661045
struct desc {
10671046
c_api::mkldnn_pooling_desc_t data;
10681047
desc(prop_kind aprop_kind, algorithm aalgorithm,

tests/gtests/test_convolution_format_any.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ using fmt = memory::format;
2525
struct conv_any_fmt_test_params {
2626
prop_kind aprop_kind;
2727
const engine::kind engine_kind;
28-
convolution_forward::algorithm aalgorithm;
28+
algorithm aalgorithm;
2929
fmt src_fmt_in;
3030
fmt src_fmt_exp;
3131
fmt weights_fmt_in;
@@ -48,7 +48,7 @@ class convolution_any_fmt_test
4848

4949
ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
5050
ASSERT_EQ(p.aprop_kind, prop_kind::forward);
51-
ASSERT_EQ(p.aalgorithm, convolution_forward::direct);
51+
ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
5252
auto eng = engine(p.engine_kind, 0);
5353
memory::data_type data_type = data_traits<data_t>::data_type;
5454
ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
@@ -113,35 +113,35 @@ TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt)
113113
}
114114
INSTANTIATE_TEST_CASE_P(TestConvolutionAnyFmtForward, conv_any_fmt_test_float,
115115
::testing::Values(conv_any_fmt_test_params_float{ prop_kind::forward,
116-
engine::kind::cpu, convolution_forward::direct, fmt::any, fmt::nchw,
116+
engine::kind::cpu, algorithm::convolution_direct, fmt::any, fmt::nchw,
117117
fmt::any, fmt::oihw, fmt::any, fmt::x, fmt::any, fmt::nchw,
118118
{ 2, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1 } }));
119119

120120
INSTANTIATE_TEST_CASE_P(
121121
TestConvolutionAlexnetAnyFmtForwardBlocked, conv_any_fmt_test_float,
122122
::testing::Values(
123123
conv_any_fmt_test_params_float{ prop_kind::forward,
124-
engine::kind::cpu, convolution_forward::direct, fmt::any,
124+
engine::kind::cpu, algorithm::convolution_direct, fmt::any,
125125
fmt::nchw, fmt::any, fmt::Ohwi8o, fmt::any, fmt::x,
126126
fmt::any, fmt::nChw8c,
127127
{ 2, 1, 3, 227, 227, 96, 55, 55, 11, 11, 0, 0, 4, 4 } },
128128
conv_any_fmt_test_params_float{ prop_kind::forward,
129-
engine::kind::cpu, convolution_forward::direct, fmt::any,
129+
engine::kind::cpu, algorithm::convolution_direct, fmt::any,
130130
fmt::nChw8c, fmt::any, fmt::gOIhw8i8o, fmt::any, fmt::x,
131131
fmt::any, fmt::nChw8c,
132132
{ 2, 2, 96, 27, 27, 256, 27, 27, 5, 5, 2, 2, 1, 1 } },
133133
conv_any_fmt_test_params_float{ prop_kind::forward,
134-
engine::kind::cpu, convolution_forward::direct, fmt::any,
134+
engine::kind::cpu, algorithm::convolution_direct, fmt::any,
135135
fmt::nChw8c, fmt::any, fmt::OIhw8i8o, fmt::any, fmt::x,
136136
fmt::any, fmt::nChw8c,
137137
{ 2, 1, 256, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
138138
conv_any_fmt_test_params_float{ prop_kind::forward,
139-
engine::kind::cpu, convolution_forward::direct, fmt::any,
139+
engine::kind::cpu, algorithm::convolution_direct, fmt::any,
140140
fmt::nChw8c, fmt::any, fmt::gOIhw8i8o, fmt::any, fmt::x,
141141
fmt::any, fmt::nChw8c,
142142
{ 2, 2, 384, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1 } },
143143
conv_any_fmt_test_params_float{ prop_kind::forward,
144-
engine::kind::cpu, convolution_forward::direct, fmt::any,
144+
engine::kind::cpu, algorithm::convolution_direct, fmt::any,
145145
fmt::nChw8c, fmt::any, fmt::gOIhw8i8o, fmt::any, fmt::x,
146146
fmt::any, fmt::nChw8c, { 2, 2, 384, 13, 13, 256, 13, 13,
147147
3, 3, 1, 1, 1, 1 } }));

0 commit comments

Comments
 (0)