Hi folks, me again.
I've noticed that using nnx.jit + nnx.grad or jax.jit + nnx.grad consumes significantly more memory than a pure functional jax.jit + jax.grad setup. Here's the output of a repro script (full script below):
```
root@computeinstance-e00p9411jfafcfq06n:/papyrax# /bin/python /papyrax/test_memory_usage.py
NNX Train Step Memory Usage:
tmp: 5.26 GB, arguments: 7.50 GB, total: 12.76 GB, host tmp 0.00 KB
Step 0, loss: 0.0007948428392410278
Step 10, loss: 0.0003609635168686509
Step 20, loss: 0.00025237887166440487
Step 30, loss: 0.00023169857740867883
Step 40, loss: 0.00022612689645029604
JAX + NNX Mix Train Step Memory Usage:
tmp: 5.26 GB, arguments: 7.50 GB, total: 12.76 GB, host tmp 0.00 KB
Step 0, loss: 0.0007948428974486887
Step 10, loss: 0.0003609635168686509
Step 20, loss: 0.00025237887166440487
Step 30, loss: 0.00023169859196059406
Step 40, loss: 0.00022612689645029604
Pure JAX Train Step Memory Usage:
tmp: 2.76 GB, arguments: 7.50 GB, total: 10.26 GB, host tmp 0.00 KB
Step 0, loss: 0.0007948428974486887
Step 10, loss: 0.0003609635168686509
Step 20, loss: 0.00025237887166440487
Step 30, loss: 0.00023169859196059406
Step 40, loss: 0.00022612689645029604
```
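The two NNX-based steps report about 2.5 GB more temp memory than the pure JAX step (5.26 GB vs 2.76 GB), while the argument size is identical. That delta is almost exactly one extra copy of the model parameters; a quick back-of-envelope check (my arithmetic, not part of the repro):

```python
# N_LAYERS x DIM x DIM float32 weights (biases are negligible):
print(10 * 8192 * 8192 * 4 / 1024**3)  # -> 2.5 (GiB)
```

So it looks like the NNX code paths keep roughly one parameter-sized buffer alive inside the step that the pure functional version doesn't, though I haven't tracked down where it comes from.

Repro script: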
```python
from typing import NamedTuple

from flax import nnx
import jax
import jax.numpy as jnp
import optax

def format_bytes(b: float) -> str:
    if b < 1024**2:
        return f"{b / 1024:.2f} KB"
    elif b < 1024**3:
        return f"{b / 1024**2:.2f} MB"
    else:
        return f"{b / 1024**3:.2f} GB"


def compile_and_report_memory(fn, *args, **kwargs):
    compiled_fn = fn.lower(*args, **kwargs).compile()
    memory_analysis = compiled_fn.memory_analysis()
    print(
        f"tmp: {format_bytes(memory_analysis.temp_size_in_bytes)}, "
        f"arguments: {format_bytes(memory_analysis.argument_size_in_bytes)}, "  # noqa: E501
        f"total: {format_bytes(memory_analysis.temp_size_in_bytes + memory_analysis.argument_size_in_bytes)}, "  # noqa: E501
        f"host tmp {format_bytes(memory_analysis.host_temp_size_in_bytes)}"
    )
    return compiled_fn

class TrainSession(NamedTuple):
    model: nnx.Module
    optimizer: nnx.Optimizer


class Block(nnx.Module):
    def __init__(self, input_dim, features, rngs):
        self.linear = nnx.Linear(input_dim, features, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.linear(x)
        x = jax.nn.relu(x)
        return x

class MLP(nnx.Module):
    def __init__(self, features, num_layers, rngs):
        self.features = features

        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_block(rngs: nnx.Rngs):
            return self._create_block(rngs)

        self.blocks = create_block(rngs)
        self.num_layers = num_layers

    def _create_block(self, rngs: nnx.Rngs) -> Block:
        return Block(self.features, self.features, rngs=rngs)

    def __call__(self, x):
        @nnx.split_rngs(splits=self.num_layers)
        @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        @nnx.remat(prevent_cse=False)
        def forward(x, model):
            x = model(x)
            return x

        return forward(x, self.blocks)

# Variant 1: nnx.jit directly over the NNX objects, nnx.value_and_grad on the module.
@nnx.jit(donate_argnums=(0,))
def train_step_nnx(train_session: TrainSession, x, y):
    def loss_fn(model):
        y_pred = model(x)  # call methods directly
        return ((y_pred - y) ** 2).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(train_session.model)
    train_session.optimizer.update(train_session.model, grads)
    return loss

DIM = 8192
N_LAYERS = 10


def get_train_session():
    model = MLP(DIM, N_LAYERS, rngs=nnx.Rngs(params=0))
    optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)
    train_state = TrainSession(
        model=model,
        optimizer=optimizer,
    )
    return train_state

def get_jax_nnx_mix_jit_train_step(train_session: TrainSession):
    # Variant 2: jax.jit over the flat NNX state, merged back into modules
    # inside the step, but still nnx.value_and_grad on the module.
    train_session_gdef = nnx.graphdef(train_session)

    @jax.jit(donate_argnums=(0,))
    def train_step(train_session_state, x, y):
        train_session: TrainSession = nnx.merge(train_session_gdef, train_session_state)

        def loss_fn(model):
            y_pred = model(x)  # call methods directly
            return ((y_pred - y) ** 2).mean()

        loss, grads = nnx.value_and_grad(loss_fn)(train_session.model)
        train_session.optimizer.update(train_session.model, grads)
        return loss, nnx.state(train_session)

    return train_step

def get_pure_jax_jit_train_step(train_session: TrainSession):
    # Variant 3: jax.jit over the flat state and jax.value_and_grad over the
    # raw model state pytree (purely functional inside the loss).
    train_session_gdef = nnx.graphdef(train_session)
    model_gdef = nnx.graphdef(train_session.model)

    @jax.jit(donate_argnums=(0,))
    def train_step(train_session_state, x, y):
        train_session: TrainSession = nnx.merge(train_session_gdef, train_session_state)
        model_state = nnx.state(train_session.model)

        def loss_fn(model_state):
            model = nnx.merge(model_gdef, model_state)
            y_pred = model(x)
            return ((y_pred - y) ** 2).mean()

        loss, grads = jax.value_and_grad(loss_fn)(model_state)
        train_session.optimizer.update(train_session.model, grads)
        return loss, nnx.state(train_session)

    return train_step

inputs = jnp.ones((16, DIM))
targets = jnp.ones((16, DIM)) * 0.03

train_session = get_train_session()
print("NNX Train Step Memory Usage:")
nnx_compiled = compile_and_report_memory(train_step_nnx, train_session, inputs, targets)
for i in range(50):
    loss = nnx_compiled(train_session, inputs, targets)
    if i % 10 == 0:
        print(f"Step {i}, loss: {loss}")

train_session = get_train_session()
train_step_jax = get_jax_nnx_mix_jit_train_step(train_session)
print("JAX + NNX Mix Train Step Memory Usage:")
jax_compiled = compile_and_report_memory(train_step_jax, nnx.state(train_session), inputs, targets)
train_state = nnx.state(train_session)
for i in range(50):
    loss, train_state = jax_compiled(train_state, inputs, targets)
    if i % 10 == 0:
        print(f"Step {i}, loss: {loss}")

train_session = get_train_session()
train_step_jax = get_pure_jax_jit_train_step(train_session)
print("Pure JAX Train Step Memory Usage:")
jax_compiled = compile_and_report_memory(train_step_jax, nnx.state(train_session), inputs, targets)
train_state = nnx.state(train_session)
for i in range(50):
    loss, train_state = jax_compiled(train_state, inputs, targets)
    if i % 10 == 0:
        print(f"Step {i}, loss: {loss}")
```