-
Notifications
You must be signed in to change notification settings - Fork 721
Description
Hi. I'm trying to apply scan over layers
approach ( https://flax.readthedocs.io/en/latest/nnx_basics.html#scan-over-layers ) and nnx.remat
to my Transformer module to save the compilation time and memory consumption.
However, auto-regressive decoding part does not work after I apply either scan over layers
or nnx.remat
to the Transformer.
Here is the simplified script to reproduce the issue.
from flax import nnx
import jax
class Transformer(nnx.Module):
def __init__(self, num_blocks=10, use_remat=True, use_scan=True, rngs=nnx.Rngs(1)):
print("Remat", use_remat, "Scan", use_scan)
self.use_remat = use_remat
self.use_scan = use_scan
if not use_scan:
self.blocks = [nnx.Linear(20, 20, rngs=rngs) for _ in range(num_blocks)]
return
@nnx.split_rngs(splits=num_blocks)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
return nnx.Linear(20, 20, rngs=rngs)
self.blocks = create_block(rngs)
def __call__(self, x):
if self.use_scan:
return self._call_scan(x)
else:
return self._call_no_scan(x)
def _call_no_scan(self, x):
for block in self.blocks:
if self.use_remat:
x = run_block_w_checkpointing(block, x)
else:
x = block(x)
return x
def _call_scan(self, x):
def run_block(x, block):
x = block(x)
return x, None
if self.use_remat:
run_block = nnx.remat(
run_block,
policy=jax.checkpoint_policies.nothing_saveable,
)
run_block = nnx.scan(
run_block,
in_axes=(nnx.Carry, 0),
out_axes=(nnx.Carry, 0),
)
x, _ = run_block(x, self.blocks)
return x
@nnx.remat(policy=jax.checkpoint_policies.nothing_saveable)
def run_block_w_checkpointing(
block,
x,
):
return block(x)
class LLM(nnx.Module):
def __init__(self, use_remat, use_scan, rngs):
self.text_encoder = nnx.Linear(10, 20, rngs=rngs)
self.transformer = Transformer(8, use_remat, use_scan, rngs=rngs)
self.text_decoder = nnx.Linear(20, 10, rngs=rngs)
def __call__(self, x):
return self.text_decoder(self.transformer(self.text_encoder(x)))
def autoregressive_decode(self, x, decode_len: int = 10):
def decode_step(carry):
idx, x = carry
x = self.text_encoder(x)
x = self.transformer(x)
x = self.text_decoder(x)
return idx + 1, x
out = nnx.while_loop(
lambda input: input[0] < decode_len,
decode_step,
(0, x),
)
return out[-1]
@nnx.jit(static_argnames=["decode_len"])
def decode(model: LLM, x: jax.Array, decode_len: int):
return model.autoregressive_decode(x, decode_len)
@nnx.jit
def normal_run(model, x):
return model(x)
if __name__ == "__main__":
import argparse
def parse_bool(x):
if x in ("true", "1", 1, "True"):
return True
return False
parser = argparse.ArgumentParser()
parser.add_argument("--use_remat", default=True, type=parse_bool)
parser.add_argument("--use_scan", default=True, type=parse_bool)
args = parser.parse_args()
llm = LLM(args.use_remat, args.use_scan, nnx.Rngs(2025))
x_key = jax.random.key(0)
x = jax.random.normal(x_key, (1, 10))
y = normal_run(llm, x)
print("NORMAL RUN SUCCESS")
y_ar = decode(llm, x, 10)
print("DECODE RUN SUCCESS")
I tried all 4 options (use_remat=1/0
, use_scan=1/0
). normal_run
works for all combinations, but decode
do not work when either use_remat
or use_scan
is 1.
When use_remat
is True, the error message is
Traceback (most recent call last):
File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 120, in <module>
y_ar = decode(llm, x, 10)
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/compilation.py", line 431, in __call__
pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/compilation.py", line 129, in __call__
pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/extract.py", line 234, in to_tree
check_consistent_aliasing(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/extract.py", line 57, in check_consistent_aliasing
value._check_valid_context(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/object.py", line 302, in _check_valid_context
raise errors.TraceContextError(error_msg())
flax.errors.TraceContextError: Trying to extract graph node from different trace level, got Linear( # Param: 420 (1.7 KB)
bias=Param( # 20 (80 B)
value=Array(shape=(20,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x75bf4f5dcee0>,
dot_general=<function dot_general at 0x75bf4fdfc670>,
dtype=None,
in_features=20,
kernel=Param( # 400 (1.6 KB)
value=Array(shape=(20, 20), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x75bf4e85e0e0>,
out_features=20,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x75bf4e85dbd0>,
use_bias=True
) (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)
When use_scan
is True (regardless of use_remat
), the error message is
Traceback (most recent call last):
File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 120, in <module>
y_ar = decode(llm, x, 10)
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/compilation.py", line 431, in __call__
pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/compilation.py", line 126, in __call__
out = self.f(*args, **kwargs)
File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 94, in decode
return model.autoregressive_decode(x, decode_len)
File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 83, in autoregressive_decode
out = nnx.while_loop(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/graph.py", line 2051, in update_context_manager_wrapper
return f(*args, **kwargs)
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/iteration.py", line 1439, in while_loop
pure_out = jax.lax.while_loop(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/graph.py", line 2051, in update_context_manager_wrapper
return f(*args, **kwargs)
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/iteration.py", line 1380, in __call__
out = self.f(val)
File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 79, in decode_step
x = self.transformer(x)
File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 23, in __call__
return self._call_scan(x)
File "/mnt/storage_8T/cjy/labs_vla/tests/test_while.py", line 52, in _call_scan
x, _ = run_block(x, self.blocks)
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/graph.py", line 2051, in update_context_manager_wrapper
return f(*args, **kwargs)
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/transforms/iteration.py", line 1213, in scan_wrapper
pure_args: tuple = extract.to_tree(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/extract.py", line 234, in to_tree
check_consistent_aliasing(
File "/mnt/storage_8T/cjy/venv/p310/lib/python3.10/site-packages/flax/nnx/extract.py", line 62, in check_consistent_aliasing
raise ValueError(
ValueError: Cannot extract graph node from different trace level, got Param( # 160 (640 B)
value=Traced<float32[8,20]>with<DynamicJaxprTrace>
)
I don't know what I am doing wrong. How can I make the decoding works with remat and scan?
I am using jaxlib==0.6.2
, jax==0.6.2
, flax==0.10.7