Hey folks,
It would be nice to have some examples in the Flax NNX documentation on how to do gradient accumulation with nnx.fori_loop.
I've tried lots of different combinations of how to pass arguments and what I can and can't put in the closure, and haven't succeeded.
My train step looks something like this:
# TrainSession is a NamedTuple with 'model: nnx.Module' and 'optimizer: nnx.Optimizer' fields
def get_training_step(
    forward_pass_cfg: ForwardPassConfigBase,
    batch_sharding: PyTree,
    num_minibatches: int = 1,
) -> Callable[
    [PyTree, TrainSession],
    tuple[DeviceArray, PyTree],
]:
    forward_pass_fn = get_forward_pass(forward_pass_cfg)

    @nnx.jit(donate_argnums=(1,))
    def training_step(
        batch: PyTree,
        train_session: TrainSession,
    ) -> tuple[DeviceArray, PyTree]:
        def _loop_body(
            minibatch_idx,
            carry: tuple[nnx.State, DeviceArray, TrainSession],
        ) -> tuple[nnx.State, DeviceArray, TrainSession]:
            g_accum, loss_accum, train_session = carry
            minibatch = get_sharded_minibatch(
                batch=batch,
                batch_sharding=batch_sharding,
                minibatch_idx=minibatch_idx,
                num_minibatches=num_minibatches,
            )

            def _loss_fn(model) -> tuple[DeviceArray, PyTree]:
                return forward_pass_fn(
                    model=model,
                    batch=minibatch,
                    step=train_session.optimizer.step.value,
                    config=forward_pass_cfg,
                )

            (loss, outputs), grads = nnx.value_and_grad(_loss_fn, has_aux=True)(
                train_session.model
            )
            g_accum = jax.tree.map(lambda gm, g: gm + g, g_accum, grads)
            loss_accum = loss_accum + loss
            return g_accum, loss_accum, train_session

        g_accum = jax.tree.map(jnp.zeros_like, nnx.state(train_session.model))
        g_accum, loss_accum, train_session = nnx.fori_loop(
            lower=0,
            upper=num_minibatches,
            body_fun=_loop_body,
            init_val=(g_accum, 0.0, train_session),
        )
        g_accum = jax.tree.map(lambda g: g / num_minibatches, g_accum)
        loss_accum = loss_accum / num_minibatches
        train_session.optimizer.update(train_session.model, g_accum)
        return loss_accum, {}

    return training_step

And I'm getting the following error:
ValueError: nnx.fori_loop requires body function's input and output to have the same reference and pytree structure, but they differ. If the mismatch comes from `outer_index` field, you might have modified reference structure within the body function, which is not allowed.
But I'm not sure what it means or how I can fix it.
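For comparison, a stripped-down nnx.fori_loop whose carry returns exactly the same references does run for me (toy nnx.Linear and made-up shapes, not my real code), so my guess is that the grads State in my carry is what changes structure:

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
x = jnp.ones((1, 4))

def body_fun(i, carry):
    model, x = carry
    x = model(x)     # reading the module is fine
    return model, x  # same references and pytree structure go back out

model, x = nnx.fori_loop(0, 3, body_fun, (model, x))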
It would be nice if Flax provided examples in the documentation on how to approach this properly: what can be done, what can't, how to work with closures, etc.
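For what it's worth, the closest I've gotten to a working version sidesteps nnx.fori_loop entirely: split the model with nnx.split outside the loop, run a plain jax.lax.fori_loop whose carry holds only arrays (the grad State and the loss), and nnx.merge inside the body. Below is a minimal self-contained sketch of that pattern; the Model class, shapes, and loss are made up, and the nnx.Optimizer constructor/update signature has changed between Flax versions, so treat those lines as approximate:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

class Model(nnx.Module):  # hypothetical stand-in for the real model
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 1, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

num_minibatches = 2
model = Model(nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-2), wrt=nnx.Param)  # newer-API signature
batch_x = jnp.ones((num_minibatches, 8, 4))  # [minibatch, batch, features]
batch_y = jnp.ones((num_minibatches, 8, 1))

@nnx.jit
def training_step(model, optimizer, batch_x, batch_y):
    # Freeze the reference structure once, outside the loop; the loop carry
    # then holds only plain arrays/State, so there is nothing to mismatch.
    graphdef, state = nnx.split(model)

    def _loop_body(i, carry):
        g_accum, loss_accum = carry
        m = nnx.merge(graphdef, state)  # rebuild the module inside the body

        def _loss_fn(m):
            pred = m(batch_x[i])
            return jnp.mean((pred - batch_y[i]) ** 2)

        loss, grads = nnx.value_and_grad(_loss_fn)(m)
        g_accum = jax.tree.map(lambda a, g: a + g, g_accum, grads)
        return g_accum, loss_accum + loss

    g_accum = jax.tree.map(jnp.zeros_like, nnx.state(model, nnx.Param))
    g_accum, loss_accum = jax.lax.fori_loop(
        0, num_minibatches, _loop_body, (g_accum, jnp.zeros(()))
    )
    g_accum = jax.tree.map(lambda g: g / num_minibatches, g_accum)
    optimizer.update(model, g_accum)
    return loss_accum / num_minibatches

loss = training_step(model, optimizer, batch_x, batch_y)

I'm not certain this is the intended pattern, which is exactly why a documented example would help.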