
[TORCHAX] Sharded jax tensors support for FlaxNNModule #9413

Open
@vlad-karp

Description


🐛 Bug

Tensors that have sharding applied cannot be moved to the JAX device after they are initialized inside FlaxNNModule.
Once a torchax tensor is sharded, its jax_device attribute holds a NamedSharding object instead of a device. _to_copy() then reads jax_device.platform, but NamedSharding objects have no platform attribute, so the call fails.
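One way to avoid the failing attribute access is to resolve the platform from either kind of object. The sketch below is a hypothetical helper (platform_of is not part of torchax); it assumes, as in JAX, that a device exposes .platform while a sharding such as NamedSharding exposes .device_set, a set of devices. The FakeDevice and FakeSharding classes are stand-ins that mimic only those attributes so the sketch runs without JAX or a TPU.

```python
from typing import Any


def platform_of(device_or_sharding: Any) -> str:
    """Hypothetical helper: return "cpu", "tpu", ... for a device or a sharding."""
    # A JAX device carries its platform name directly.
    platform = getattr(device_or_sharding, "platform", None)
    if isinstance(platform, str):
        return platform
    # A JAX sharding (e.g. NamedSharding) has no .platform, but lists its devices.
    device_set = getattr(device_or_sharding, "device_set", None)
    if device_set:
        platforms = {d.platform for d in device_set}
        if len(platforms) != 1:
            raise ValueError(f"sharding spans multiple platforms: {sorted(platforms)}")
        return platforms.pop()
    raise TypeError(f"cannot determine platform of {type(device_or_sharding).__name__}")


# Stand-ins for jax.Device and jax.sharding.NamedSharding (attributes only).
class FakeDevice:
    def __init__(self, platform: str):
        self.platform = platform


class FakeSharding:
    def __init__(self, devices):
        self.device_set = set(devices)


if __name__ == "__main__":
    print(platform_of(FakeDevice("tpu")))
    print(platform_of(FakeSharding([FakeDevice("tpu"), FakeDevice("tpu")])))
```

With real JAX objects, the same helper would accept jax.devices()[0] as well as the .sharding of a sharded jax.Array; whether torchax's _to_copy should adopt this approach is for the maintainers to decide.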

To Reproduce

import torch
import jax
import jax.numpy as jnp
import torchax as tx
import flax.linen as nn
from torchax.flax import FlaxNNModule
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_embeddings = 100000
features_dim = 128
batch_size = 256

tx.enable_performance_mode()
tx.enable_globally()
env = tx.default_env()
env.config.use_torch_native_for_cpu_tensor = False
print(f"Local devices num: {jax.local_device_count()}")

with tx.jax_device("tpu"):
    with env:
        nnx_emb = nn.Embed(num_embeddings=num_embeddings, features=features_dim)
        sample_input = jnp.ones((batch_size, batch_size), dtype=jnp.int32)

        # 1-D mesh over all TPU devices; parameters are sharded along "data".
        global_devices = jax.devices("tpu")
        mesh_1d = Mesh(global_devices, "data")
        sharding = NamedSharding(mesh_1d, PartitionSpec("data"))

        # Wrap the Flax init so the initialized parameters come out sharded.
        orig_init_fn = nnx_emb.init

        def sharded_init(prng, *sample_args, **sample_kwargs):
            return jax.jit(orig_init_fn, out_shardings=sharding)(
                prng, *sample_args, **sample_kwargs
            )

        nnx_emb.init = sharded_init  # type: ignore
        emb_module = FlaxNNModule(env, nnx_emb, (sample_input,))
        emb_module.to('jax')  # raises the AttributeError shown below

Steps to reproduce the behavior:

  1. Run the code above to get the following exception:

    File ".../xla/torchax/torchax/tensor.py", line 445, in _to_copy
        current_platform = the_tensor.jax_device.platform
    AttributeError: 'NamedSharding' object has no attribute 'platform'
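The root cause can be checked with JAX alone, without torchax or a TPU. This sketch assumes only jax and numpy are installed: it builds a 1-D mesh over whatever local devices exist (CPU works), places an array with a NamedSharding, and shows that the sharding exposes device_set but no platform attribute; the platform name is only available on the individual devices.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over all local devices, mirroring the "data" axis in the repro.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

# Leading dimension is a multiple of the device count so it shards evenly.
x = jax.device_put(np.zeros((devices.size * 4, 8), dtype=np.float32), sharding)

print(type(x.sharding).__name__)
# The sharding itself has no .platform attribute ...
print(hasattr(x.sharding, "platform"))
# ... but each device in its device_set does.
print({d.platform for d in x.sharding.device_set})
```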

Expected behavior

The module should be moved to the JAX device without errors.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 2.8.0.dev

Additional context
