@@ -35,6 +35,7 @@ status_t lnorm_desc_init(layer_normalization_desc_t *lnorm_desc,
3535 bool args_ok = true && !any_null (lnorm_desc, data_desc)
3636 && one_of (prop_kind, forward_training, forward_inference,
3737 backward_data, backward)
38+ && 2 <= data_desc->ndims && data_desc->ndims <= 5
3839 && IMPLICATION (prop_kind & backward, diff_data_desc != nullptr )
3940 && (flags & ~(dnnl_use_global_stats | dnnl_use_scaleshift)) == 0 ;
4041 if (!args_ok) return invalid_arguments;
@@ -79,17 +80,15 @@ status_t lnorm_desc_init(layer_normalization_desc_t *lnorm_desc,
7980 ld.layer_norm_epsilon = epsilon;
8081 ld.flags = flags;
8182
82- bool consistency = true && utils::one_of (ld.data_desc .ndims , 2 , 3 , 4 , 5 );
83- if (ld.prop_kind == backward_data)
84- consistency = consistency
85- && utils::one_of (ld.diff_data_desc .ndims , 2 , 3 , 4 , 5 )
83+ if (ld.prop_kind == backward_data) {
84+ bool consistency = ld.diff_data_desc .ndims == ld.data_desc .ndims
8685 && array_cmp (ld.diff_data_desc .dims , ld.data_desc .dims ,
8786 ld.diff_data_desc .ndims )
8887 && ld.data_desc .ndims == ld.stat_desc .ndims + 1
8988 && array_cmp (ld.stat_desc .dims , ld.data_desc .dims ,
9089 ld.stat_desc .ndims );
91-
92- if (!consistency) return invalid_arguments;
90+ if (!consistency) return invalid_arguments;
91+ }
9392
9493 *lnorm_desc = ld;
9594 return success;
0 commit comments