Train step with nnx.jit/jax.jit+nnx.grad consumes twice the memory of the jax.jit+jax.grad equivalent #5116


Description

@qGentry

Hi folks, me again.

I've noticed that using nnx.jit+nnx.grad, or jax.jit+nnx.grad, consumes significantly more memory than the pure functional jax.jit+jax.grad equivalent. Here's a repro: the output comes first, then (for orientation) a minimal reduction of the two gradient styles, and then the full script.

root@computeinstance-e00p9411jfafcfq06n:/papyrax# /bin/python /papyrax/test_memory_usage.py 
NNX Train Step Memory Usage:
tmp: 5.26 GB, arguments: 7.50 GB, total: 12.76 GB, host tmp 0.00 KB
Step 0, loss: 0.0007948428392410278
Step 10, loss: 0.0003609635168686509
Step 20, loss: 0.00025237887166440487
Step 30, loss: 0.00023169857740867883
Step 40, loss: 0.00022612689645029604
JAX + NNX Mix Train Step Memory Usage:
tmp: 5.26 GB, arguments: 7.50 GB, total: 12.76 GB, host tmp 0.00 KB
Step 0, loss: 0.0007948428974486887
Step 10, loss: 0.0003609635168686509
Step 20, loss: 0.00025237887166440487
Step 30, loss: 0.00023169859196059406
Step 40, loss: 0.00022612689645029604
Pure JAX Train Step Memory Usage:
tmp: 2.76 GB, arguments: 7.50 GB, total: 10.26 GB, host tmp 0.00 KB
Step 0, loss: 0.0007948428974486887
Step 10, loss: 0.0003609635168686509
Step 20, loss: 0.00025237887166440487
Step 30, loss: 0.00023169859196059406
Step 40, loss: 0.00022612689645029604
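
For orientation, the variants differ only in where the gradient is taken. Stripped of the MLP/scan/optimizer details, the two gradient styles look like this (my own minimal reduction, not the repro itself, and no claim about its memory behavior):

from flax import nnx
import jax
import jax.numpy as jnp

model = nnx.Linear(16, 16, rngs=nnx.Rngs(0))
x = jnp.ones((4, 16))
y = jnp.zeros((4, 16))

# NNX style: differentiate the Module object itself.
@nnx.jit
def step_nnx(model, x, y):
    def loss_fn(m):
        return ((m(x) - y) ** 2).mean()
    return nnx.value_and_grad(loss_fn)(model)

# Pure JAX style: differentiate the flat nnx.State pytree and rebuild the
# Module from (graphdef, state) only inside the loss.
graphdef, state = nnx.split(model)

@jax.jit
def step_jax(state, x, y):
    def loss_fn(s):
        return ((nnx.merge(graphdef, s)(x) - y) ** 2).mean()
    return jax.value_and_grad(loss_fn)(state)

loss_a, _ = step_nnx(model, x, y)
loss_b, _ = step_jax(state, x, y)
print(loss_a, loss_b)  # identical losses, different gradient plumbing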
Repro script:

from functools import partial
from typing import NamedTuple

from flax import nnx
import jax
import jax.numpy as jnp
import optax



def format_bytes(b: float) -> str:
    if b < 1024**2:
        return f"{b / 1024:.2f} KB"
    elif b < 1024**3:
        return f"{b / 1024**2:.2f} MB"
    else:
        return f"{b / 1024**3:.2f} GB"


def compile_and_report_memory(fn, *args, **kwargs):
    compiled_fn = fn.lower(*args, **kwargs).compile()
    memory_analysis = compiled_fn.memory_analysis()
    print(
        f"tmp: {format_bytes(memory_analysis.temp_size_in_bytes)}, "
        f"arguments: {format_bytes(memory_analysis.argument_size_in_bytes)}, "  # noqa: E501
        f"total: {format_bytes(memory_analysis.temp_size_in_bytes + memory_analysis.argument_size_in_bytes)}, "  # noqa: E501
        f"host tmp {format_bytes(memory_analysis.host_temp_size_in_bytes)}"
    )
    return compiled_fn


class TrainSession(NamedTuple):
    model: nnx.Module
    optimizer: nnx.Optimizer



class Block(nnx.Module):
    def __init__(self, input_dim, features, rngs):
        self.linear = nnx.Linear(input_dim, features, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.linear(x)
        x = jax.nn.relu(x)
        return x


class MLP(nnx.Module):
    def __init__(self, features, num_layers, rngs):
        self.features = features

        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_block(rngs: nnx.Rngs):
            return self._create_block(rngs)

        self.blocks = create_block(rngs)
        self.num_layers = num_layers

    def _create_block(self, rngs: nnx.Rngs) -> Block:
        return Block(self.features, self.features, rngs=rngs)

    def __call__(self, x):
        @nnx.split_rngs(splits=self.num_layers)
        @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        @nnx.remat(prevent_cse=False)
        def forward(x, model):
            x = model(x)
            return x

        return forward(x, self.blocks)

@nnx.jit(donate_argnums=(0,))
def train_step_nnx(train_session: TrainSession, x, y):
    def loss_fn(model):
        y_pred = model(x)  # call methods directly
        return ((y_pred - y) ** 2).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(train_session.model)
    train_session.optimizer.update(train_session.model, grads)
    return loss


DIM = 8192
N_LAYERS = 10

def get_train_session():
    model = MLP(DIM, N_LAYERS, rngs=nnx.Rngs(params=0))
    optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

    train_state = TrainSession(
        model=model,
        optimizer=optimizer,
    )
    return train_state

def get_jax_nnx_mix_jit_train_step(train_session: TrainSession):
    train_session_gdef = nnx.graphdef(train_session)
    @partial(jax.jit, donate_argnums=(0,))
    def train_step(train_session_state, x, y):
        train_session: TrainSession = nnx.merge(train_session_gdef, train_session_state)
        def loss_fn(model):
            y_pred = model(x)  # call methods directly
            return ((y_pred - y) ** 2).mean()
        loss, grads = nnx.value_and_grad(loss_fn)(train_session.model)
        train_session.optimizer.update(train_session.model, grads)
        return loss, nnx.state(train_session)
    return train_step


def get_pure_jax_jit_train_step(train_session: TrainSession):
    train_session_gdef = nnx.graphdef(train_session)
    model_gdef = nnx.graphdef(train_session.model)
    @partial(jax.jit, donate_argnums=(0,))
    def train_step(train_session_state, x, y):
        train_session: TrainSession = nnx.merge(train_session_gdef, train_session_state)
        model_state = nnx.state(train_session.model)

        def loss_fn(model_state):
            model = nnx.merge(model_gdef, model_state)
            y_pred = model(x)
            return ((y_pred - y) ** 2).mean()
        loss, grads = jax.value_and_grad(loss_fn)(model_state)

        train_session.optimizer.update(train_session.model, grads)
        return loss, nnx.state(train_session)
    return train_step



inputs = jnp.ones((16, DIM))
targets = jnp.ones((16, DIM)) * 0.03

train_session = get_train_session()
print("NNX Train Step Memory Usage:")
nnx_compiled = compile_and_report_memory(train_step_nnx, train_session, inputs, targets)
for i in range(50):
    loss = nnx_compiled(train_session, inputs, targets)
    if i % 10 == 0:
        print(f"Step {i}, loss: {loss}")


train_session = get_train_session()
train_step_jax = get_jax_nnx_mix_jit_train_step(train_session)
print("JAX + NNX Mix Train Step Memory Usage:")
jax_compiled = compile_and_report_memory(train_step_jax, nnx.state(train_session), inputs, targets)

train_state = nnx.state(train_session)
for i in range(50):
    loss, train_state = jax_compiled(train_state, inputs, targets)
    if i % 10 == 0:
        print(f"Step {i}, loss: {loss}")


train_session = get_train_session()
train_step_jax = get_pure_jax_jit_train_step(train_session)
print("Pure JAX Train Step Memory Usage:")
jax_compiled = compile_and_report_memory(train_step_jax, nnx.state(train_session), inputs, targets)

train_state = nnx.state(train_session)
for i in range(50):
    loss, train_state = jax_compiled(train_state, inputs, targets)
    if i % 10 == 0:
        print(f"Step {i}, loss: {loss}")
