@@ -67,8 +67,6 @@ class lnorm_test : public ::testing::TestWithParam<test_lnorm_params_t> {
6767 }
6868
6969 void Test () {
70- p = ::testing::TestWithParam<decltype (p)>::GetParam ();
71-
7270 eng = engine (get_test_engine_kind (), 0 );
7371 strm = stream (eng);
7472
@@ -106,6 +104,8 @@ class lnorm_test : public ::testing::TestWithParam<test_lnorm_params_t> {
106104
107105 void Forward (
108106 prop_kind pk, normalization_flags flags = (normalization_flags)0u) {
107+ fwd_iface_test_stat_any (pk, flags);
108+
109109 bool useScaleShift
110110 = (bool )(flags & normalization_flags::use_scale_shift);
111111 bool useGlobalStats
@@ -141,6 +141,8 @@ class lnorm_test : public ::testing::TestWithParam<test_lnorm_params_t> {
141141
142142 void Backward (
143143 prop_kind pk, normalization_flags flags = (normalization_flags)0u) {
144+ bwd_iface_test_stat_any (pk, flags);
145+
144146 bool useScaleShift
145147 = (bool )(flags & normalization_flags::use_scale_shift);
146148
@@ -452,6 +454,94 @@ class lnorm_test : public ::testing::TestWithParam<test_lnorm_params_t> {
452454 }
453455 });
454456 }
457+
458+ void fwd_iface_test_stat_any (prop_kind pk, normalization_flags flags) {
459+ // non stats if inference w/o use global stats
460+ if (pk == prop_kind::forward_inference
461+ && !(bool )(flags & normalization_flags::use_global_stats))
462+ return ;
463+
464+ using tag = memory::format_tag;
465+
466+ tag expect_stat_tag = derive_stat_tag ();
467+ if (expect_stat_tag == tag::undef) return ; // optimism
468+
469+ memory::dims stat_dims (p.dims .begin (), p.dims .end () - 1 );
470+ memory::desc expect_stat_md (
471+ stat_dims, memory::data_type::f32 , expect_stat_tag);
472+
473+ // no stat_md provided at all
474+ {
475+ layer_normalization_forward::primitive_desc fwd_pd (
476+ {pk, *data_d, p.epsilon , flags}, eng);
477+
478+ EXPECT_EQ (fwd_pd.mean_desc (), expect_stat_md);
479+ EXPECT_EQ (fwd_pd.variance_desc (), expect_stat_md);
480+ }
481+
482+ // stat_md with format_tag::any
483+ {
484+ memory::desc any_stat_md (
485+ stat_dims, memory::data_type::f32 , tag::any);
486+ layer_normalization_forward::primitive_desc fwd_pd (
487+ {pk, *data_d, any_stat_md, p.epsilon , flags}, eng);
488+
489+ EXPECT_EQ (fwd_pd.mean_desc (), expect_stat_md);
490+ EXPECT_EQ (fwd_pd.variance_desc (), expect_stat_md);
491+ }
492+ }
493+
494+ void bwd_iface_test_stat_any (prop_kind pk, normalization_flags flags) {
495+ using tag = memory::format_tag;
496+
497+ tag expect_stat_tag = derive_stat_tag ();
498+ if (expect_stat_tag == tag::undef) return ; // optimism
499+
500+ memory::dims stat_dims (p.dims .begin (), p.dims .end () - 1 );
501+ memory::desc expect_stat_md (
502+ stat_dims, memory::data_type::f32 , expect_stat_tag);
503+
504+ layer_normalization_forward::primitive_desc fwd_pd (
505+ {prop_kind::forward_training, *data_d, p.epsilon , flags}, eng);
506+
507+ // no stat_md provided at all
508+ {
509+ layer_normalization_backward::primitive_desc bwd_pd (
510+ {pk, *diff_d, *data_d, p.epsilon , flags}, eng, fwd_pd);
511+
512+ EXPECT_EQ (bwd_pd.mean_desc (), expect_stat_md);
513+ EXPECT_EQ (bwd_pd.variance_desc (), expect_stat_md);
514+ }
515+
516+ // stat_md with format_tag::any
517+ {
518+ memory::desc any_stat_md (
519+ stat_dims, memory::data_type::f32 , tag::any);
520+ layer_normalization_backward::primitive_desc bwd_pd (
521+ {pk, *diff_d, *data_d, any_stat_md, p.epsilon , flags}, eng,
522+ fwd_pd);
523+
524+ EXPECT_EQ (bwd_pd.mean_desc (), expect_stat_md);
525+ EXPECT_EQ (bwd_pd.variance_desc (), expect_stat_md);
526+ }
527+ }
528+
529+ private:
530+ memory::format_tag derive_stat_tag () const {
531+ using tag = memory::format_tag;
532+ tag expect_stat_tag = tag::undef;
533+
534+ // TODO: add more cases and test cases
535+ // XXX: currently test only simple cases like `abc`, `acb`. Extend,
536+ // if possible, to blocked formats too.
537+ switch (p.data_tag ) {
538+ case tag::abc: expect_stat_tag = tag::ab; break ;
539+ case tag::bac: expect_stat_tag = tag::ba; break ;
540+ default : break ;
541+ }
542+
543+ return expect_stat_tag;
544+ }
455545};
456546
457547TEST_P (lnorm_test, TestsLnormF32) {}
0 commit comments