nnx.jit() cannot specify backend or device due to out_shardings #4774

Description

@wookayin

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux, but OS-agnostic
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
    • flax 0.10.6
    • jax 0.6.0
    • jaxlib 0.6.0 (cuda12)
  • Python version: 3.12
  • GPU/TPU model and memory: not relevant
  • CUDA version (if applicable): 12.6

Problem you have encountered:

nnx.jit() does not seem to fully support the backend=... or device=... parameters the way vanilla jax.jit() does: jax.jit requires out_shardings to be left unspecified whenever backend or device is given, but nnx.jit always passes one.

ValueError: If backend or device is specified on jit, then out_shardings should not be specified.

What you expected to happen:

nnx.jit(fn, backend=...) or nnx.jit(fn, device=...) should work seamlessly, just like vanilla jax.jit(). The minimal reproduction code below should run and print a [2, 4] array.
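
For contrast, the equivalent vanilla jax.jit usage runs fine. A minimal sketch (jax.jit accepts backend= as long as no shardings are passed alongside it):

import functools

import jax
import jax.numpy as jnp


# Vanilla jax.jit accepts backend='cpu' because no out_shardings are
# forwarded together with it.
@functools.partial(jax.jit, backend='cpu')
def double(x):
    return x * 2.0


print(double(jnp.ones([2, 3])))  # runs on CPU, no ValueError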

Logs, error messages, etc:

This is probably because nnx's jit wrapper always passes a 3-tuple (jax_in_shardings, kwargs_shardings, jax_out_shardings) as out_shardings to jax.jit, mirroring the structure of JitFn's output, even when the caller never requested any shardings.

See https://github.com/google/flax/blob/main/flax/nnx/transforms/compilation.py#L379
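
The underlying jax constraint is easy to confirm directly on vanilla jax.jit. A minimal sketch (the SingleDeviceSharding here is just an arbitrary explicit sharding):

import jax
import jax.numpy as jnp

cpu = jax.devices('cpu')[0]


def f(x):
    return x * 2.0


# Combining backend= with any explicit out_shardings triggers the same
# ValueError that nnx.jit runs into.
try:
    jitted = jax.jit(f, backend='cpu',
                     out_shardings=jax.sharding.SingleDeviceSharding(cpu))
    jitted(jnp.ones([2, 3]))
except ValueError as e:
    print(e)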

Steps to reproduce:

A minimal reproduction:

import flax.nnx as nnx
import jax.numpy as jnp

model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))


# Fails: nnx.jit always forwards out_shardings to jax.jit, which jax.jit
# rejects whenever backend= (or device=) is also given.
@nnx.jit(backend='cpu')
def foo(model: nnx.Linear, x):
    return model(x)


batch_size: int = 2
y = foo(model, jnp.ones([batch_size, 3]))
assert y.shape == (batch_size, 4)
print(y)  # expected: a [2, 4] array

Expected: no error
Actual: ValueError: If backend or device is specified on jit, then out_shardings should not be specified.
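
A possible stopgap until this is fixed: drop backend= from nnx.jit and steer placement with jax.default_device instead. A sketch (placement semantics differ slightly from backend='cpu' for arrays already committed to another device):

import flax.nnx as nnx
import jax
import jax.numpy as jnp


@nnx.jit  # no backend=/device=, so the out_shardings nnx forwards stay legal
def foo(model: nnx.Linear, x):
    return model(x)


# jax.default_device places uncommitted arrays (the params and inputs
# created inside the context) on the chosen device, so the jitted
# computation runs there.
with jax.default_device(jax.devices('cpu')[0]):
    model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
    y = foo(model, jnp.ones([2, 3]))

print(y)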

Labels

Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required.)
