Don't create params in normalization layers instead of creating None-value params.

This makes these layers align better with the behavior of Linen layers, and also reduces confusion.

PiperOrigin-RevId: 719498793
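A minimal sketch of the resulting behavior (assuming the released `flax.nnx` API; the layer choice and feature count are just illustrative): with `use_scale`/`use_bias` disabled, the attributes are now plain `None` rather than `nnx.Param(None)`.

```python
from flax import nnx

# Illustrative check: with scale/bias disabled, the attributes are now
# plain None instead of nnx.Param(None), matching Linen's behavior.
layer = nnx.LayerNorm(num_features=8, use_scale=False, use_bias=False,
                      rngs=nnx.Rngs(0))
assert layer.scale is None
assert layer.bias is None
```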
IvyZX authored and Flax Authors committed Jan 25, 2025
1 parent d28f03f commit a7b0ede
Showing 1 changed file with 21 additions and 14 deletions.
flax/nnx/nn/normalization.py
@@ -287,17 +287,19 @@ def __init__(
     self.mean = nnx.BatchStat(jnp.zeros(feature_shape, jnp.float32))
     self.var = nnx.BatchStat(jnp.ones(feature_shape, jnp.float32))
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
-      self.scale = nnx.Param(None)
+      self.scale = None
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.num_features = num_features
     self.use_running_average = use_running_average
@@ -368,8 +370,8 @@ def __call__(
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       reduction_axes,
       feature_axes,
       self.dtype,
@@ -454,17 +456,19 @@ def __init__(
   ):
     feature_shape = (num_features,)
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
-      self.scale = nnx.Param(None)
+      self.scale = None
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.num_features = num_features
     self.epsilon = epsilon
@@ -503,8 +507,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       self.reduction_axes,
       self.feature_axes,
       self.dtype,
@@ -582,11 +586,12 @@ def __init__(
   ):
     feature_shape = (num_features,)
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
-      self.scale = nnx.Param(None)
+      self.scale = None
 
     self.num_features = num_features
     self.epsilon = epsilon
@@ -624,7 +629,7 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
+      self.scale.value if self.scale else None,
       None,
       self.reduction_axes,
       self.feature_axes,
@@ -757,17 +762,19 @@ def __init__(
     self.group_size = num_features // num_groups
 
     feature_shape = (num_features,)
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
-      self.scale = nnx.Param(None)
+      self.scale = None
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.epsilon = epsilon
     self.dtype = dtype
@@ -822,8 +829,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       reduction_axes[:-1],
       (self.feature_axis,),
       self.dtype,
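One user-visible consequence, sketched under the assumption that this commit is applied (the layer and names below are illustrative): a disabled scale/bias no longer appears in the module's `Param` state at all, so downstream code that read `layer.scale.value` unconditionally should now guard for `None`, mirroring the call sites above.

```python
from flax import nnx

# Hypothetical downstream usage after this change.
norm = nnx.RMSNorm(num_features=8, use_scale=False, rngs=nnx.Rngs(0))

# No None-valued Param is created, so nothing shows up in the Param state:
print(nnx.state(norm, nnx.Param))  # empty: 'scale' is plain None

# Code that previously read norm.scale.value unconditionally now guards:
scale = norm.scale.value if norm.scale is not None else None
assert scale is None
```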
