Skip to content

nnx.while_loop does not work with Scan over layers or nnx.remat #4819

@gliese581gg

Description

@gliese581gg

Hi. I'm trying to apply scan over layers approach ( https://flax.readthedocs.io/en/latest/nnx_basics.html#scan-over-layers ) and nnx.remat to my Transformer module to save the compilation time and memory consumption.

However, auto-regressive decoding part does not work after I apply either scan over layers or nnx.remat to the Transformer.

Here is the simplified script to reproduce the issue.

from flax import nnx
import jax


class Transformer(nnx.Module):
    def __init__(self, num_blocks=10, use_remat=True, use_scan=True, rngs=nnx.Rngs(1)):
        print("Remat", use_remat, "Scan", use_scan)
        self.use_remat = use_remat
        self.use_scan = use_scan
        if not use_scan:
            self.blocks = [nnx.Linear(20, 20, rngs=rngs) for _ in range(num_blocks)]
            return

        @nnx.split_rngs(splits=num_blocks)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_block(rngs: nnx.Rngs):
            return nnx.Linear(20, 20, rngs=rngs)

        self.blocks = create_block(rngs)

    def __call__(self, x):
        if self.use_scan:
            return self._call_scan(x)
        else:
            return self._call_no_scan(x)

    def _call_no_scan(self, x):
        for block in self.blocks:
            if self.use_remat:
                x = run_block_w_checkpointing(block, x)
            else:
                x = block(x)
        return x

    def _call_scan(self, x):
        def run_block(x, block):
            x = block(x)
            return x, None

        if self.use_remat:
            run_block = nnx.remat(
                run_block,
                policy=jax.checkpoint_policies.nothing_saveable,
            )

        run_block = nnx.scan(
            run_block,
            in_axes=(nnx.Carry, 0),
            out_axes=(nnx.Carry, 0),
        )

        x, _ = run_block(x, self.blocks)

        return x


@nnx.remat(policy=jax.checkpoint_policies.nothing_saveable)
def run_block_w_checkpointing(
    block,
    x,
):
    return block(x)


class LLM(nnx.Module):
    def __init__(self, use_remat, use_scan, rngs):
        self.text_encoder = nnx.Linear(10, 20, rngs=rngs)
        self.transformer = Transformer(8, use_remat, use_scan, rngs=rngs)
        self.text_decoder = nnx.Linear(20, 10, rngs=rngs)

    def __call__(self, x):
        return self.text_decoder(self.transformer(self.text_encoder(x)))

    def autoregressive_decode(self, x, decode_len: int = 10):

        def decode_step(carry):
            idx, x = carry
            x = self.text_encoder(x)
            x = self.transformer(x)
            x = self.text_decoder(x)
            return idx + 1, x

        out = nnx.while_loop(
            lambda input: input[0] < decode_len,
            decode_step,
            (0, x),
        )

        return out[-1]


@nnx.jit(static_argnames=["decode_len"])
def decode(model: LLM, x: jax.Array, decode_len: int):
    return model.autoregressive_decode(x, decode_len)


@nnx.jit
def normal_run(model, x):
    return model(x)


if __name__ == "__main__":
    import argparse

    def parse_bool(x):
        if x in ("true", "1", 1, "True"):
            return True
        return False

    parser = argparse.ArgumentParser()
    parser.add_argument("--use_remat", default=True, type=parse_bool)
    parser.add_argument("--use_scan", default=True, type=parse_bool)
    args = parser.parse_args()

    llm = LLM(args.use_remat, args.use_scan, nnx.Rngs(2025))
    x_key = jax.random.key(0)
    x = jax.random.normal(x_key, (1, 10))
    y = normal_run(llm, x)
    print("NORMAL RUN SUCCESS")
    y_ar = decode(llm, x, 10)
    print("DECODE RUN SUCCESS")

I tried all 4 options (use_remat=1/0, use_scan=1/0). normal_run works for all combinations, but decode do not work when either use_remat or use_scan is 1.

When use_remat is True, the error message is

Traceback (most recent call last):
  File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 120, in <module>
    y_ar = decode(llm, x, 10)
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/compilation.py", line 431, in __call__
    pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/compilation.py", line 129, in __call__
    pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/extract.py", line 234, in to_tree
    check_consistent_aliasing(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/extract.py", line 57, in check_consistent_aliasing
    value._check_valid_context(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/object.py", line 302, in _check_valid_context
    raise errors.TraceContextError(error_msg())
flax.errors.TraceContextError: Trying to extract graph node from different trace level, got Linear( # Param: 420 (1.7 KB)
  bias=Param( # 20 (80 B)
    value=Array(shape=(20,), dtype=dtype('float32'))
  ),
  bias_init=<function zeros at 0x75bf4f5dcee0>,
  dot_general=<function dot_general at 0x75bf4fdfc670>,
  dtype=None,
  in_features=20,
  kernel=Param( # 400 (1.6 KB)
    value=Array(shape=(20, 20), dtype=dtype('float32'))
  ),
  kernel_init=<function variance_scaling.<locals>.init at 0x75bf4e85e0e0>,
  out_features=20,
  param_dtype=float32,
  precision=None,
  promote_dtype=<function promote_dtype at 0x75bf4e85dbd0>,
  use_bias=True
) (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)

When use_scan is True (regardless of use_remat), the error message is

Traceback (most recent call last):
  File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 120, in <module>
    y_ar = decode(llm, x, 10)
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/compilation.py", line 431, in __call__
    pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/compilation.py", line 126, in __call__
    out = self.f(*args, **kwargs)
  File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 94, in decode
    return model.autoregressive_decode(x, decode_len)
  File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 83, in autoregressive_decode
    out = nnx.while_loop(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/graph.py", line 2051, in update_context_manager_wrapper
    return f(*args, **kwargs)
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/iteration.py", line 1439, in while_loop
    pure_out = jax.lax.while_loop(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/graph.py", line 2051, in update_context_manager_wrapper
    return f(*args, **kwargs)
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/iteration.py", line 1380, in __call__
    out = self.f(val)
  File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 79, in decode_step
    x = self.transformer(x)
  File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 23, in __call__
    return self._call_scan(x)
  File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 52, in _call_scan
    x, _ = run_block(x, self.blocks)
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/graph.py", line 2051, in update_context_manager_wrapper
    return f(*args, **kwargs)
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/iteration.py", line 1213, in scan_wrapper
    pure_args: tuple = extract.to_tree(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/extract.py", line 234, in to_tree
    check_consistent_aliasing(
  File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/extract.py", line 62, in check_consistent_aliasing
    raise ValueError(
ValueError: Cannot extract graph node from different trace level, got Param( # 160 (640 B)
  value=Traced<float32[8,20]>with<DynamicJaxprTrace>
)

I don't know what I am doing wrong. How can I make the decoding works with remat and scan?

I am using jaxlib==0.6.2, jax==0.6.2, flax==0.10.7

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions