Deprecation of nnx.Remat and usage of nnx.remat #4778

@lraedel

Hi team,

I noticed that nnx.Remat has been deprecated, but this is not easy to spot in the changelog, so it came as a bit of a surprise. Trying to replace our usage of nnx.Remat with nnx.remat led to several issues, so I would like to understand what the recommended usage pattern is.

The pattern we had, and would like to keep, is to take a layer such as the one below and dynamically add remat to it, with different policies depending on the use case.

import flax.nnx as nnx
import jax
import jax.numpy as jnp


class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.linear(x)

We tried dynamically replacing __call__ on the module, which did not work and is most likely not a good idea anyway.

I would appreciate it if you could share the recommended usage pattern, and whether there is a way to reuse a single module definition and dynamically apply remat to it with multiple policies.

Thank you!
