Creating a running average with self.variable #1005
Hi all, I tried to create a running average for a module, but I am getting an error when calling `net.apply`. Could someone point me to what I am doing wrong? Thanks!

```python
import jax.numpy as jnp
from jax import random
import flax.linen as nn


class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        # False the first time (during init), True once the variable exists
        is_initialized = self.has_variable('moving_stats', 'mean')
        mean = self.variable('moving_stats', 'mean', jnp.zeros, [3])
        if is_initialized:
            # Exponential moving average update
            mean.value = 0.9 * mean.value + 0.1 * x
        return mean.value


key = random.PRNGKey(0)
x = random.normal(key, shape=(2, 3))
net = Net()
params = net.init(key, x)
y = net.apply(params, x)  # the error is raised here
```
Replies: 1 comment
By default all variables are immutable during `apply`. This is to avoid accidental side effects in otherwise stateless code. Here you should use:

```python
y, new_state = net.apply(params, x, mutable=['moving_stats'])
```

`new_state` will be a dict containing `'moving_stats'` with the updated statistics (here, the running `mean`).