Skip to content

Commit e5d5e05

Browse files
author
Fomenko, Evarist M
committed
gtests: lnorm: check the simplicity of stat_md if tag is any
1 parent 944bcb8 commit e5d5e05

File tree

1 file changed

+92
-2
lines changed

1 file changed

+92
-2
lines changed

tests/gtests/test_layer_normalization.cpp

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,6 @@ class lnorm_test : public ::testing::TestWithParam<test_lnorm_params_t> {
6767
}
6868

6969
void Test() {
70-
p = ::testing::TestWithParam<decltype(p)>::GetParam();
71-
7270
eng = engine(get_test_engine_kind(), 0);
7371
strm = stream(eng);
7472

@@ -106,6 +104,8 @@ class lnorm_test : public ::testing::TestWithParam<test_lnorm_params_t> {
106104

107105
void Forward(
108106
prop_kind pk, normalization_flags flags = (normalization_flags)0u) {
107+
fwd_iface_test_stat_any(pk, flags);
108+
109109
bool useScaleShift
110110
= (bool)(flags & normalization_flags::use_scale_shift);
111111
bool useGlobalStats
@@ -141,6 +141,8 @@ class lnorm_test : public ::testing::TestWithParam<test_lnorm_params_t> {
141141

142142
void Backward(
143143
prop_kind pk, normalization_flags flags = (normalization_flags)0u) {
144+
bwd_iface_test_stat_any(pk, flags);
145+
144146
bool useScaleShift
145147
= (bool)(flags & normalization_flags::use_scale_shift);
146148

@@ -452,6 +454,94 @@ class lnorm_test : public ::testing::TestWithParam<test_lnorm_params_t> {
452454
}
453455
});
454456
}
457+
458+
void fwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) {
459+
// non stats if inference w/o use global stats
460+
if (pk == prop_kind::forward_inference
461+
&& !(bool)(flags & normalization_flags::use_global_stats))
462+
return;
463+
464+
using tag = memory::format_tag;
465+
466+
tag expect_stat_tag = derive_stat_tag();
467+
if (expect_stat_tag == tag::undef) return; // optimism
468+
469+
memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
470+
memory::desc expect_stat_md(
471+
stat_dims, memory::data_type::f32, expect_stat_tag);
472+
473+
// no stat_md provided at all
474+
{
475+
layer_normalization_forward::primitive_desc fwd_pd(
476+
{pk, *data_d, p.epsilon, flags}, eng);
477+
478+
EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md);
479+
EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md);
480+
}
481+
482+
// stat_md with format_tag::any
483+
{
484+
memory::desc any_stat_md(
485+
stat_dims, memory::data_type::f32, tag::any);
486+
layer_normalization_forward::primitive_desc fwd_pd(
487+
{pk, *data_d, any_stat_md, p.epsilon, flags}, eng);
488+
489+
EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md);
490+
EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md);
491+
}
492+
}
493+
494+
void bwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) {
495+
using tag = memory::format_tag;
496+
497+
tag expect_stat_tag = derive_stat_tag();
498+
if (expect_stat_tag == tag::undef) return; // optimism
499+
500+
memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
501+
memory::desc expect_stat_md(
502+
stat_dims, memory::data_type::f32, expect_stat_tag);
503+
504+
layer_normalization_forward::primitive_desc fwd_pd(
505+
{prop_kind::forward_training, *data_d, p.epsilon, flags}, eng);
506+
507+
// no stat_md provided at all
508+
{
509+
layer_normalization_backward::primitive_desc bwd_pd(
510+
{pk, *diff_d, *data_d, p.epsilon, flags}, eng, fwd_pd);
511+
512+
EXPECT_EQ(bwd_pd.mean_desc(), expect_stat_md);
513+
EXPECT_EQ(bwd_pd.variance_desc(), expect_stat_md);
514+
}
515+
516+
// stat_md with format_tag::any
517+
{
518+
memory::desc any_stat_md(
519+
stat_dims, memory::data_type::f32, tag::any);
520+
layer_normalization_backward::primitive_desc bwd_pd(
521+
{pk, *diff_d, *data_d, any_stat_md, p.epsilon, flags}, eng,
522+
fwd_pd);
523+
524+
EXPECT_EQ(bwd_pd.mean_desc(), expect_stat_md);
525+
EXPECT_EQ(bwd_pd.variance_desc(), expect_stat_md);
526+
}
527+
}
528+
529+
private:
530+
memory::format_tag derive_stat_tag() const {
531+
using tag = memory::format_tag;
532+
tag expect_stat_tag = tag::undef;
533+
534+
// TODO: add more cases and test cases
535+
// XXX: currently test only simple cases like `abc`, `acb`. Extend,
536+
// if possible, to blocked formats too.
537+
switch (p.data_tag) {
538+
case tag::abc: expect_stat_tag = tag::ab; break;
539+
case tag::bac: expect_stat_tag = tag::ba; break;
540+
default: break;
541+
}
542+
543+
return expect_stat_tag;
544+
}
455545
};
456546

457547
TEST_P(lnorm_test, TestsLnormF32) {}

0 commit comments

Comments
 (0)