Description
When developing a quantization library, we often need to collect statistics of activations (for example, their absolute maximum). Sometimes this collection happens inside a jax.custom_vjp function, as in the example below.
import jax
import jax.numpy as jnp
from flax import nnx

class QuantStats(nnx.Variable):
  # __setattr__ = object.__setattr__
  pass

class Model(nnx.Module):
  def __init__(self):
    self.stats = QuantStats({'absmax': jnp.zeros(())})
    self.linear = nnx.Linear(12, 10, rngs=nnx.Rngs(0))

  def __call__(self, x):
    @jax.custom_vjp
    def f(x):
      return fwd(x)[0]

    def fwd(x):
      # Record the activation statistic as a side effect of the forward rule.
      self.stats.value = {'absmax': jnp.max(jnp.abs(x))}
      return x, ()

    def bwd(_, g):
      # custom_vjp bwd must return a tuple with one cotangent per primal input.
      return (g,)

    f.defvjp(fwd, bwd)
    return self.linear(f(x))

def loss_fn(model, x):
  out = model(x)
  return jnp.sum(jnp.abs(out))

model = Model()
loss_fn(model, jnp.full((1, 12), 42.))
print(model.stats['absmax'])
nnx.grad(loss_fn)(model, jnp.full((1, 12), 43.))
print(model.stats['absmax'])
Currently, running the code above raises an error like the following:
/tmp/ipython-input-90-2380234258.py in fwd(x)
     15 
     16     def fwd(x):
---> 17       self.stats.value = {'absmax': jnp.max(jnp.abs(x))}
     18       return x, ()
     19 

.../flax/nnx/variablelib.py in __setattr__(self, name, value)
    274       name != 'value' or not self.mutable
    275     ):
--> 276       raise errors.TraceContextError(
    277         f'Cannot mutate {type(self).__name__} from a different trace level'
    278       )

TraceContextError: Cannot mutate QuantStats from a different trace level (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)
However, JAX does not require a custom_vjp function to be pure: the code above works if we uncomment the __setattr__ = object.__setattr__ line in QuantStats, which bypasses the check in Variable.__setattr__.
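For comparison, the same pattern can be sketched in plain JAX without NNX. This is a minimal illustration rather than code from the report; stats, record_absmax, record_absmax_fwd, record_absmax_bwd and loss are hypothetical names. The forward rule writes the statistic into an ordinary Python dict, and jax.grad accepts the side effect.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for QuantStats: a plain dict with no trace-level checks.
stats = {}


@jax.custom_vjp
def record_absmax(x):
  return x  # primal: identity, no side effect


def record_absmax_fwd(x):
  # Side effect inside the forward rule, like fwd in the NNX example.
  stats['absmax'] = jnp.max(jnp.abs(x))
  return x, None  # no residuals needed


def record_absmax_bwd(_, g):
  return (g,)  # pass the cotangent straight through


record_absmax.defvjp(record_absmax_fwd, record_absmax_bwd)


def loss(x):
  return jnp.sum(record_absmax(x) ** 2)


g = jax.grad(loss)(jnp.full((3,), 2.0))
print(g)  # gradient of sum(x**2) is 2*x
```

The write happens each time JAX invokes the forward rule while differentiating loss; calling loss directly only runs the primal record_absmax, which does not touch stats.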
It seems that NNX is imposing an unnecessary check here.