Examples of how to use fori_loop for gradient accumulation & clearer exceptions #5113

@qGentry

Description

Hey folks,

It would be nice to have some examples in the Flax NNX documentation on how to do gradient accumulation with nnx.fori_loop.

I've tried lots of different combinations of how to pass arguments and what I can and can't put in the closure, but haven't succeeded.

My train step looks something like this:

from typing import Callable

import jax
import jax.numpy as jnp
from flax import nnx

# PyTree, DeviceArray, ForwardPassConfigBase, get_forward_pass and
# get_sharded_minibatch are project-specific and omitted here.
# TrainSession is a NamedTuple with 'model: nnx.Module' and
# 'optimizer: nnx.Optimizer' fields.

def get_training_step(
    forward_pass_cfg: ForwardPassConfigBase,
    batch_sharding: PyTree,
    num_minibatches: int = 1,
) -> Callable[
    [PyTree, TrainSession],
    tuple[DeviceArray, PyTree],
]:
    forward_pass_fn = get_forward_pass(forward_pass_cfg)

    @nnx.jit(donate_argnums=(1,))
    def training_step(
        batch: PyTree,
        train_session: TrainSession,
    ) -> tuple[DeviceArray, PyTree]:

        def _loop_body(
            minibatch_idx,
            carry: tuple[nnx.State, DeviceArray, TrainSession],
        ) -> tuple[nnx.State, DeviceArray, TrainSession]:
            g_accum, loss_accum, train_session = carry
            minibatch = get_sharded_minibatch(
                batch=batch,
                batch_sharding=batch_sharding,
                minibatch_idx=minibatch_idx,
                num_minibatches=num_minibatches,
            )

            def _loss_fn(model) -> tuple[DeviceArray, PyTree]:
                return forward_pass_fn(
                    model=model,
                    batch=minibatch,
                    step=train_session.optimizer.step.value,
                    config=forward_pass_cfg,
                )

            (loss, outputs), grads = nnx.value_and_grad(_loss_fn, has_aux=True)(
                train_session.model
            )
            g_accum = jax.tree.map(
                lambda gm, g: gm + g, g_accum, grads
            )
            loss_accum = loss_accum + loss
            return g_accum, loss_accum, train_session

        g_accum = jax.tree.map(jnp.zeros_like, nnx.state(train_session.model))
        g_accum, loss_accum, train_session = nnx.fori_loop(
            lower=0, 
            upper=num_minibatches,
            body_fun=_loop_body,
            init_val=(g_accum, 0.0, train_session)
        )

        g_accum = jax.tree.map(lambda g: g / num_minibatches, g_accum)
        loss_accum = loss_accum / num_minibatches

        train_session.optimizer.update(train_session.model, g_accum)
        return loss_accum, {}

    return training_step

And I'm getting the following error:

ValueError: nnx.fori_loop requires body function's input and output to have the same reference and pytree structure, but they differ. If the mismatch comes from `outer_index` field, you might have modified reference structure within the body function, which is not allowed.

But I'm not sure what it means or how to fix it.
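
My understanding (which may well be wrong) is that the pytree half of that constraint is the same one plain jax.lax.fori_loop imposes: the carry has to leave the body with exactly the same structure it entered with. A toy illustration of just that part:

import jax

def good_body(i, carry):
    a, b = carry
    return a + i, b              # carry keeps its (x, y) structure -> OK

def bad_body(i, carry):
    a, b = carry
    return a + i, (b, b)         # structure changes -> fori_loop raises

jax.lax.fori_loop(0, 3, good_body, (0, 0))    # works
# jax.lax.fori_loop(0, 3, bad_body, (0, 0))   # TypeError: structures differ

The "reference structure" part of the message is the bit I can't map onto my code, since _loop_body returns the same train_session it receives.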

It would be nice if the Flax documentation provided examples of how to approach this properly: what can and can't be done, how to work with closures, and so on.
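
For reference, this is roughly the shape of example I was hoping to find: a sketch of the loop body inside training_step using a plain jax.lax.fori_loop over a functional nnx.split of the model, with only plain pytrees in the carry. It reuses forward_pass_fn, get_sharded_minibatch, batch_sharding and forward_pass_cfg from my snippet above, and I don't know whether it is the recommended pattern (or even correct):

        # Inside training_step, replacing the nnx.fori_loop block above.
        # Split into a static graph definition and a plain-pytree state.
        graphdef, state = nnx.split(train_session.model)
        step = train_session.optimizer.step.value

        def _loop_body(minibatch_idx, carry):
            g_accum, loss_accum = carry  # plain pytrees only
            minibatch = get_sharded_minibatch(
                batch=batch,
                batch_sharding=batch_sharding,
                minibatch_idx=minibatch_idx,
                num_minibatches=num_minibatches,
            )

            def _loss_fn(params):
                # Rebuild the module functionally from the differentiated params.
                model = nnx.merge(graphdef, params)
                return forward_pass_fn(
                    model=model,
                    batch=minibatch,
                    step=step,
                    config=forward_pass_cfg,
                )

            (loss, _), grads = jax.value_and_grad(_loss_fn, has_aux=True)(state)
            g_accum = jax.tree.map(lambda a, g: a + g, g_accum, grads)
            return g_accum, loss_accum + loss

        g_accum = jax.tree.map(jnp.zeros_like, state)
        g_accum, loss_accum = jax.lax.fori_loop(
            0, num_minibatches, _loop_body, (g_accum, 0.0)
        )
        g_accum = jax.tree.map(lambda g: g / num_minibatches, g_accum)
        loss_accum = loss_accum / num_minibatches
        train_session.optimizer.update(train_session.model, g_accum)

The key difference is that the model is handled functionally via nnx.split/nnx.merge, so nothing carrying NNX references goes through the loop carry. If that is indeed the intended way to use fori_loop with modules, an example like this in the docs would have saved me a lot of guessing.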
