Skip to content

Commit a709edb

Browse files
akharitotprimak
authored andcommitted
benchdnn: minor improvements
- aligned readed/writed paddings values (0 is not default value for padding) - corrected 3d problem check for conv/deconv and pool
1 parent 2a6cb12 commit a709edb

File tree

4 files changed

+49
-46
lines changed

4 files changed

+49
-46
lines changed

tests/benchdnn/conv/conv.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,6 @@
3131

3232
namespace conv {
3333

34-
inline bool is_conv_3d(const prb_t *p) {
35-
return p->id > 1;
36-
}
37-
38-
inline bool is_conv_1d(const prb_t *p) {
39-
return !is_conv_3d(p) && p->ih == 1 && p->kh == 1;
40-
}
41-
4234
double get_trust_nz_level(const prb_t *p, data_kind_t kind,
4335
bool final_compare) {
4436
if (!final_compare)
@@ -423,7 +415,7 @@ inline int init_pd(const prb_t *p, mkldnn_convolution_desc_t &cd,
423415
mkldnn_primitive_desc_t &cpd, res_t *r) {
424416
mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
425417

426-
int ndims = is_conv_3d(p) ? 5 : is_conv_1d(p) ? 3 : 4;
418+
int ndims = is_problem_3d(p) ? 5 : is_problem_1d(p) ? 3 : 4;
427419
mkldnn_dims_t src_1d_dims = {p->mb, p->ic, p->iw};
428420
mkldnn_dims_t src_2d_dims = {p->mb, p->ic, p->ih, p->iw};
429421
mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
@@ -439,23 +431,29 @@ inline int init_pd(const prb_t *p, mkldnn_convolution_desc_t &cd,
439431
mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
440432

441433
DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
442-
is_conv_3d(p) ? src_3d_dims : is_conv_1d(p) ? src_1d_dims : src_2d_dims,
443-
p->cfg[SRC].dt, mkldnn_any), WARN);
434+
is_problem_3d(p)
435+
? src_3d_dims
436+
: is_problem_1d(p) ? src_1d_dims : src_2d_dims,
437+
p->cfg[SRC].dt, mkldnn_any),
438+
WARN);
444439

445440
DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
446-
is_conv_3d(p)
447-
? &wei_3d_dims[!p->has_groups]
448-
: is_conv_1d(p)
449-
? &wei_1d_dims[!p->has_groups]
450-
: &wei_2d_dims[!p->has_groups],
451-
p->cfg[WEI].dt, mkldnn_any), WARN);
441+
is_problem_3d(p)
442+
? &wei_3d_dims[!p->has_groups]
443+
: is_problem_1d(p) ? &wei_1d_dims[!p->has_groups]
444+
: &wei_2d_dims[!p->has_groups],
445+
p->cfg[WEI].dt, mkldnn_any),
446+
WARN);
452447

453448
DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt,
454449
mkldnn_any), WARN);
455450

456451
DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
457-
is_conv_3d(p) ? dst_3d_dims : is_conv_1d(p) ? dst_1d_dims : dst_2d_dims,
458-
p->cfg[DST].dt, mkldnn_any), WARN);
452+
is_problem_3d(p)
453+
? dst_3d_dims
454+
: is_problem_1d(p) ? dst_1d_dims : dst_2d_dims,
455+
p->cfg[DST].dt, mkldnn_any),
456+
WARN);
459457

460458
int strides_nd[] = {p->sd, p->sh, p->sw};
461459
int dilates_nd[] = {p->dd, p->dh, p->dw};

tests/benchdnn/conv/conv_aux.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,23 +188,23 @@ void desc2str(const desc_t *d, char *buffer, bool canonical) {
188188
if (canonical || d->has_groups) DPRINT("g%d", d->g);
189189
if (canonical || d->mb != 2) DPRINT("mb%d", d->mb);
190190

191-
const bool half_form = (d->ih == d->iw && d->kh == d->kw && d->oh == d->ow
192-
&& d->sh == d->sw && d->ph == d->pw && d->dh == d->dw) && d->id == 1;
191+
const bool half_form
192+
= (d->ih == d->iw && d->kh == d->kw && d->oh == d->ow
193+
&& d->sh == d->sw && d->ph == d->pw && d->dh == d->dw)
194+
&& !is_problem_3d(d);
193195

194196
if (!canonical && half_form) {
195197
DPRINT("ic%dih%doc%doh%dkh%d", d->ic, d->ih, d->oc, d->oh, d->kh);
196198
if (d->sh != 1) DPRINT("sh%d", d->sh);
197199
if (d->ph != 0) DPRINT("ph%d", d->ph);
198200
if (d->dh != 0) DPRINT("dh%d", d->dh);
199201
} else {
200-
if( d->id == 1 )
201-
{
202+
if (!is_problem_3d(d)) {
202203
DPRINT("ic%dih%diw%doc%doh%dow%dkh%dkw%d",
203204
d->ic, d->ih, d->iw, d->oc, d->oh, d->ow, d->kh, d->kw);
204205
if (canonical || d->sh != 1 || d->sw != 1)
205206
DPRINT("sh%dsw%d", d->sh, d->sw);
206-
if (canonical || d->ph != 0 || d->pw != 0)
207-
DPRINT("ph%dpw%d", d->ph, d->pw);
207+
DPRINT("ph%dpw%d", d->ph, d->pw);
208208
if (canonical || d->dh != 0 || d->dw != 0)
209209
DPRINT("dh%ddw%d", d->dh, d->dw);
210210
} else {
@@ -213,8 +213,7 @@ void desc2str(const desc_t *d, char *buffer, bool canonical) {
213213
d->kd, d->kh, d->kw);
214214
if (canonical || d->sh != 1 || d->sw != 1 || d->sd != 1)
215215
DPRINT("sd%dsh%dsw%d", d->sd, d->sh, d->sw);
216-
if (canonical || d->ph != 0 || d->pw != 0 || d->pd != 0)
217-
DPRINT("pd%dph%dpw%d", d->pd, d->ph, d->pw);
216+
DPRINT("pd%dph%dpw%d", d->pd, d->ph, d->pw);
218217
if (canonical || d->dh != 0 || d->dw != 0 || d->dd != 0)
219218
DPRINT("dd%ddh%ddw%d", d->dd, d->dh, d->dw);
220219
}

tests/benchdnn/conv/conv_common.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ struct desc_t {
5252

5353
const char *name;
5454
};
55+
56+
inline bool is_problem_3d(const desc_t *p) {
57+
return p->id > 1 || p->kd > 1;
58+
}
59+
60+
inline bool is_problem_1d(const desc_t *p) {
61+
return !is_problem_3d(p) && p->ih == 1 && p->kh == 1;
62+
}
63+
5564
const size_t max_desc_len = 196;
5665
int str2desc(desc_t *desc, const char *str, bool is_deconv);
5766
void desc2str(const desc_t *d, char *buffer, bool canonical = false);

tests/benchdnn/conv/deconv.cpp

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,6 @@ inline static void swap(int &a, int &b) {
3939
b = temp;
4040
}
4141

42-
inline bool is_deconv_3d(const prb_t *p) {
43-
return p->id > 1;
44-
}
45-
46-
inline bool is_deconv_1d(const prb_t *p) {
47-
return !is_deconv_3d(p) && p->ih == 1 && p->kh == 1;
48-
}
49-
5042
inline int transpose_data_wei(const prb_t *p, dnn_mem_t &wei, dnn_mem_t &wei_tr) {
5143
mkldnn::impl::parallel_nd(
5244
p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw,
@@ -61,7 +53,7 @@ inline int transpose_data_wei(const prb_t *p, dnn_mem_t &wei, dnn_mem_t &wei_tr)
6153

6254
inline int init_pd(const prb_t *p, mkldnn_deconvolution_desc_t &cd,
6355
mkldnn_primitive_desc_t &dpd, res_t *r) {
64-
int ndims = is_deconv_3d(p) ? 5 : is_deconv_1d(p) ? 3 : 4;
56+
int ndims = is_problem_3d(p) ? 5 : is_problem_1d(p) ? 3 : 4;
6557

6658
mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
6759
mkldnn_dims_t src_1d_dims = {p->mb, p->ic, p->iw};
@@ -75,20 +67,25 @@ inline int init_pd(const prb_t *p, mkldnn_deconvolution_desc_t &cd,
7567
mkldnn_dims_t dst_2d_dims = {p->mb, p->oc, p->oh, p->ow};
7668
mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
7769
DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
78-
is_deconv_3d(p) ? src_3d_dims : is_deconv_1d(p) ? src_1d_dims : src_2d_dims,
79-
p->cfg[SRC].dt, mkldnn_any),
70+
is_problem_3d(p)
71+
? src_3d_dims
72+
: is_problem_1d(p) ? src_1d_dims : src_2d_dims,
73+
p->cfg[SRC].dt, mkldnn_any),
8074
WARN);
8175
DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
82-
is_deconv_3d(p)
83-
? &wei_3d_dims[!p->has_groups]
84-
: is_deconv_1d(p)
85-
? &wei_1d_dims[!p->has_groups]
86-
: &wei_2d_dims[!p->has_groups],
87-
p->cfg[WEI].dt, mkldnn_any), WARN);
76+
is_problem_3d(p)
77+
? &wei_3d_dims[!p->has_groups]
78+
: is_problem_1d(p) ? &wei_1d_dims[!p->has_groups]
79+
: &wei_2d_dims[!p->has_groups],
80+
p->cfg[WEI].dt, mkldnn_any),
81+
WARN);
8882
DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
8983
DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
90-
is_deconv_3d(p) ? dst_3d_dims : is_deconv_1d(p) ? dst_1d_dims : dst_2d_dims,
91-
p->cfg[DST].dt, mkldnn_any), WARN);
84+
is_problem_3d(p)
85+
? dst_3d_dims
86+
: is_problem_1d(p) ? dst_1d_dims : dst_2d_dims,
87+
p->cfg[DST].dt, mkldnn_any),
88+
WARN);
9289

9390
int strides_nd[] = {p->sd, p->sh, p->sw};
9491
int dilates_nd[] = {p->dd, p->dh, p->dw};

0 commit comments

Comments
 (0)