Description
Hi team,
I noticed that `nnx.Remat` has been deprecated, but this is not easy to spot in the changelog, so it came as a bit of a surprise. Trying to replace our usage of `nnx.Remat` with `nnx.remat` led to several issues, so I wanted to understand what the recommended usage pattern is.
The usage pattern we had, and want to keep, is this: given a layer such as the one below, we can dynamically add remat to it with different checkpoint policies depending on the use case.
```python
import flax.nnx as nnx
import jax
import jax.numpy as jnp


class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.linear(x)
```
We tried dynamically changing `__call__` to apply remat, which did not work and is most likely not a good idea anyway.
We would appreciate it if you could share the recommended usage pattern, and whether there is a way to reuse a module definition and dynamically apply remat to it with different policies.
Thank you!