Labels: Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required.)
Description
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux, but OS-agnostic
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
  - flax 0.10.6
  - jax 0.6.0
  - jaxlib 0.6.0 (cuda12)
- Python version: 3.12
- GPU/TPU model and memory: not relevant
- CUDA version (if applicable): 12.6
Problem you have encountered:
`nnx.jit()` doesn't fully support the `backend=...` or `device=...` parameters the way vanilla `jax.jit()` does, because `out_shardings` must be left unspecified when a backend or device is given:

```
ValueError: If backend or device is specified on jit, then out_shardings should not be specified.
```
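The same constraint can be reproduced at the pure-JAX level, with no Flax involved; a minimal sketch (the lambda and the explicit sharding below are illustrative only):

```python
import jax
import jax.numpy as jnp

# Any explicit out_shardings combined with backend= (or device=) triggers
# the same ValueError, directly in jax.jit.
cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])

try:
    f = jax.jit(lambda x: x * 2, backend='cpu', out_shardings=cpu_sharding)
    f(jnp.ones([2, 3]))
except ValueError as e:
    print(e)
    # If backend or device is specified on jit, then out_shardings should
    # not be specified.
```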
What you expected to happen:
`nnx.jit(fn, backend=...)` or `nnx.jit(fn, device=...)` should work as seamlessly as vanilla `jit()` does. The minimal reproduction code below should run and print a `[2, 4]` array.
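For comparison, a minimal sketch showing that vanilla `jax.jit` accepts `backend=...` as long as `out_shardings` is left unspecified (the `matmul` helper is illustrative only):

```python
import jax
import jax.numpy as jnp

def matmul(w, x):
    return x @ w

# Plain jax.jit is happy with backend=... because it leaves out_shardings
# unspecified; this is the behavior nnx.jit is expected to mirror.
matmul_cpu = jax.jit(matmul, backend='cpu')
y = matmul_cpu(jnp.ones([3, 4]), jnp.ones([2, 3]))
print(y.shape)     # (2, 4)
print(y.devices()) # e.g. {CpuDevice(id=0)}
```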
Logs, error messages, etc:
This is probably because NNX's jit wrapper always passes a 3-tuple (`jax_in_shardings`, `kwargs_shardings`, `jax_out_shardings`) as `out_shardings` to `jax.jit`, matching the output structure of `JitFn`.
See https://github.com/google/flax/blob/main/flax/nnx/transforms/compilation.py#L379
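One possible shape of a fix, purely as a hypothetical sketch (the function and variable names below are illustrative stand-ins and do not claim to match Flax's actual internals): only forward `out_shardings` to `jax.jit` when neither `backend` nor `device` was given.

```python
import jax

# Hypothetical sketch of a fix; the real change would live in
# flax/nnx/transforms/compilation.py around the linked line.
def build_jitted_fn(jit_fn, jax_in_shardings, kwargs_shardings,
                    jax_out_shardings, *, device=None, backend=None):
    jit_kwargs = {}
    if device is None and backend is None:
        # Only constrain outputs when jax.jit allows it; the 3-tuple mirrors
        # JitFn's (args, kwargs, out) return structure. in_shardings may need
        # the same treatment, since jax.jit applies the same check to it.
        jit_kwargs['out_shardings'] = (
            jax_in_shardings, kwargs_shardings, jax_out_shardings)
    return jax.jit(jit_fn, device=device, backend=backend, **jit_kwargs)
```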
Steps to reproduce:
A minimal reproduction:
```python
import flax.nnx as nnx
import jax.numpy as jnp

model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))

@nnx.jit(backend='cpu')
def foo(model: nnx.Linear, x):
    return model(x)

batch_size: int = 2
y = foo(model, jnp.ones([batch_size, 3]))
assert y.shape == (batch_size, 4)
print(y)
```
Expected: no error.
Actual: `ValueError: If backend or device is specified on jit, then out_shardings should not be specified.`
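Until this is fixed, one possible workaround (an assumption on my part, not a documented pattern) is to drop `backend=` from `nnx.jit` and steer placement with `jax.default_device` instead:

```python
import flax.nnx as nnx
import jax
import jax.numpy as jnp

model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))

@nnx.jit  # no backend=/device= here, so nnx.jit's out_shardings are allowed
def foo(model: nnx.Linear, x):
    return model(x)

# Possible workaround: pin the default device instead of passing backend=.
# This assumes the model's parameters are uncommitted (created without an
# explicit device placement), so JAX is free to move them.
with jax.default_device(jax.devices('cpu')[0]):
    y = foo(model, jnp.ones([2, 3]))

assert y.shape == (2, 4)
print(y)
```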