From a7b0edebf644b75c7d0d7bd096d6860454308c0b Mon Sep 17 00:00:00 2001
From: Ivy Zheng
Date: Fri, 24 Jan 2025 18:12:45 -0800
Subject: [PATCH] Don't create None-valued params in normalization layers; set the attributes to None instead.

When use_scale or use_bias is disabled, the normalization layers now store
plain None instead of wrapping it in nnx.Param(None). This makes these
layers align better with the behavior of Linen layers and also reduces
confusion.

PiperOrigin-RevId: 719498793
---
 flax/nnx/nn/normalization.py | 35 +++++++++++++++++++++--------------
 1 file changed, 21 insertions(+), 14 deletions(-)

diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py
index 921a030f2..72c6450cf 100644
--- a/flax/nnx/nn/normalization.py
+++ b/flax/nnx/nn/normalization.py
@@ -287,17 +287,19 @@ def __init__(
     self.mean = nnx.BatchStat(jnp.zeros(feature_shape, jnp.float32))
     self.var = nnx.BatchStat(jnp.ones(feature_shape, jnp.float32))
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
-      self.scale = nnx.Param(None)
+      self.scale = None
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.num_features = num_features
     self.use_running_average = use_running_average
@@ -368,8 +370,8 @@ def __call__(
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       reduction_axes,
       feature_axes,
       self.dtype,
@@ -454,17 +456,19 @@ def __init__(
   ):
     feature_shape = (num_features,)
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
-      self.scale = nnx.Param(None)
+      self.scale = None
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.num_features = num_features
     self.epsilon = epsilon
@@ -503,8 +507,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       self.reduction_axes,
       self.feature_axes,
       self.dtype,
@@ -582,11 +586,12 @@ def __init__(
   ):
     feature_shape = (num_features,)
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
-      self.scale = nnx.Param(None)
+      self.scale = None
 
     self.num_features = num_features
     self.epsilon = epsilon
@@ -624,7 +629,7 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
+      self.scale.value if self.scale else None,
       None,
       self.reduction_axes,
       self.feature_axes,
@@ -757,17 +762,19 @@ def __init__(
     self.group_size = num_features // num_groups
     feature_shape = (num_features,)
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
-      self.scale = nnx.Param(None)
+      self.scale = None
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.epsilon = epsilon
     self.dtype = dtype
@@ -822,8 +829,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       reduction_axes[:-1],
       (self.feature_axis,),
       self.dtype,
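
A minimal usage sketch of the resulting behavior (an illustration assuming
this patch is applied, not part of the patch itself; nnx.LayerNorm and
nnx.Rngs are the public flax.nnx API touched by the diff above):

    import jax.numpy as jnp
    from flax import nnx

    # With use_scale/use_bias disabled, the attributes are now plain None
    # rather than nnx.Param(None).
    ln = nnx.LayerNorm(num_features=4, use_scale=False, use_bias=False,
                       rngs=nnx.Rngs(0))
    assert ln.scale is None and ln.bias is None

    # With the defaults, they remain nnx.Params holding real arrays.
    ln = nnx.LayerNorm(num_features=4, rngs=nnx.Rngs(0))
    assert ln.scale.value.shape == (4,)

    y = ln(jnp.ones((2, 4)))  # __call__ passes None through when disabled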