Skip to content

Commit 4ea278b

Browse files
author
Fomenko, Evarist M
committed
common: lnorm: validate the number of dims
1 parent 062a29d commit 4ea278b

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

src/common/layer_normalization.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ status_t lnorm_desc_init(layer_normalization_desc_t *lnorm_desc,
3535
bool args_ok = true && !any_null(lnorm_desc, data_desc)
3636
&& one_of(prop_kind, forward_training, forward_inference,
3737
backward_data, backward)
38+
&& 2 <= data_desc->ndims && data_desc->ndims <= 5
3839
&& IMPLICATION(prop_kind & backward, diff_data_desc != nullptr)
3940
&& (flags & ~(dnnl_use_global_stats | dnnl_use_scaleshift)) == 0;
4041
if (!args_ok) return invalid_arguments;
@@ -79,17 +80,15 @@ status_t lnorm_desc_init(layer_normalization_desc_t *lnorm_desc,
7980
ld.layer_norm_epsilon = epsilon;
8081
ld.flags = flags;
8182

82-
bool consistency = true && utils::one_of(ld.data_desc.ndims, 2, 3, 4, 5);
83-
if (ld.prop_kind == backward_data)
84-
consistency = consistency
85-
&& utils::one_of(ld.diff_data_desc.ndims, 2, 3, 4, 5)
83+
if (ld.prop_kind == backward_data) {
84+
bool consistency = ld.diff_data_desc.ndims == ld.data_desc.ndims
8685
&& array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims,
8786
ld.diff_data_desc.ndims)
8887
&& ld.data_desc.ndims == ld.stat_desc.ndims + 1
8988
&& array_cmp(ld.stat_desc.dims, ld.data_desc.dims,
9089
ld.stat_desc.ndims);
91-
92-
if (!consistency) return invalid_arguments;
90+
if (!consistency) return invalid_arguments;
91+
}
9392

9493
*lnorm_desc = ld;
9594
return success;

0 commit comments

Comments
 (0)