
Unnecessary check for trace level in nnx.Variable.__setattr__ #4847

@liudangyi

Description


When developing a quantization library, we often need to collect statistics of activations. Sometimes the collection happens inside a custom_vjp function, as demonstrated below:

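import jax
import jax.numpy as jnp
from flax import nnx


class QuantStats(nnx.Variable):
  # Workaround: uncommenting the next line bypasses NNX's trace-level check.
  # __setattr__ = object.__setattr__
  pass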


class Model(nnx.Module):
  def __init__(self):
    self.stats = QuantStats({'absmax': jnp.zeros(())})
    self.linear = nnx.Linear(12, 10, rngs=nnx.Rngs(0))

  def __call__(self, x):
    @jax.custom_vjp
    def f(x):
      return fwd(x)[0]

    def fwd(x):
      self.stats.value = {'absmax': jnp.max(jnp.abs(x))}
      return x, ()

    def bwd(_, g):
      # A custom_vjp bwd rule must return a tuple with one cotangent per primal input.
      return (g,)

    f.defvjp(fwd, bwd)
    return self.linear(f(x))


def loss_fn(model, x):
  out = model(x)
  return jnp.sum(jnp.abs(out))


model = Model()
loss_fn(model, jnp.full((1, 12), 42.))
print(model.stats['absmax'])
nnx.grad(loss_fn)(model, jnp.full((1, 12), 43.))
print(model.stats['absmax'])

Currently, running the code above raises an error like this:

/tmp/ipython-input-90-2380234258.py in fwd(x)
     15 
     16     def fwd(x):
---> 17       self.stats.value = {'absmax': jnp.max(jnp.abs(x))}
     18       return x, ()
     19

.../flax/nnx/variablelib.py in __setattr__(self, name, value)
    274       name != 'value' or not self.mutable
    275     ):
--> 276       raise errors.TraceContextError(
    277         f'Cannot mutate {type(self).__name__} from a different trace level'
    278       )

TraceContextError: Cannot mutate QuantStats from a different trace level (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)

However, JAX doesn't require a custom_vjp function to be pure, and the code above works if the __setattr__ = object.__setattr__ line is uncommented.
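To illustrate the point about purity, here is a minimal plain-JAX sketch (no NNX involved) in which the forward rule of a custom_vjp records a statistic into an ordinary Python dict. The names record_absmax, record_absmax_fwd, record_absmax_bwd and the stats dict are hypothetical stand-ins for the QuantStats variable above. The sketch assumes eager (non-jit) execution, where the forward rule receives concrete values; under jax.jit the stored value would be a tracer instead.

```python
import jax
import jax.numpy as jnp

# Hypothetical plain-Python container standing in for QuantStats.
stats = {}


@jax.custom_vjp
def record_absmax(x):
  # Primal path reuses the forward rule, so stats are recorded either way.
  return record_absmax_fwd(x)[0]


def record_absmax_fwd(x):
  # Side effect: stash a statistic of the activation in outer state.
  stats['absmax'] = jnp.max(jnp.abs(x))
  return x, None  # identity output, no residuals needed


def record_absmax_bwd(_, g):
  # Identity gradient, returned as a tuple (one entry per primal input).
  return (g,)


record_absmax.defvjp(record_absmax_fwd, record_absmax_bwd)


def loss_fn(x):
  return jnp.sum(jnp.abs(record_absmax(x)))


loss_fn(jnp.full((1, 12), 42.0))
print(stats['absmax'])  # 42.0
grads = jax.grad(loss_fn)(jnp.full((1, 12), 43.0))
print(stats['absmax'])  # 43.0
```

Under jax.grad, JAX calls the forward rule with the primal inputs, so the side effect runs during differentiation as well as during the plain call.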

It seems that NNX is imposing an unnecessary check here.
