🐛 Bug
Tensors with sharding applied can't be moved to a jax device after initialization inside `FlaxNNModule`.
Once sharded, a torchax tensor's `jax_device` attribute becomes a `NamedSharding` object, and a subsequent `_to_copy()` call reads `jax_device.platform`, an attribute that `NamedSharding` objects do not have.
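A minimal sketch of the attribute mismatch, independent of torchax (variable names here are just for illustration): a concrete `jax.Device` exposes `.platform`, while a `NamedSharding` only exposes the devices it spans.

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

dev = jax.devices()[0]
print(dev.platform)  # a concrete Device has .platform, e.g. "cpu" or "tpu"

mesh = Mesh(jax.devices(), "data")
sharding = NamedSharding(mesh, PartitionSpec("data"))
print(hasattr(sharding, "platform"))  # False -> AttributeError in _to_copy
# The sharding does expose its underlying devices, e.g. via device_set:
print(next(iter(sharding.device_set)).platform)
```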
To Reproduce
```python
import torch
import jax
import jax.numpy as jnp
from typing import Any
import torchax as tx
import flax.linen as nn
from torchax.flax import FlaxNNModule
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_embeddings = 100000
features_dim = 128
batch_size = 256

tx.enable_performance_mode()
tx.enable_globally()
env = tx.default_env()
env.config.use_torch_native_for_cpu_tensor = False
print(f"Local devices num: {jax.local_device_count()}")

with tx.jax_device("tpu"):
    with env:
        nnx_emb = nn.Embed(num_embeddings=num_embeddings, features=features_dim)
        sample_input = jnp.ones((batch_size, batch_size), dtype=jnp.int32)

        # Shard the parameters across all TPU devices along a 1D "data" mesh.
        global_devices = jax.devices("tpu")
        mesh_1d = Mesh(global_devices, "data")
        sharding = NamedSharding(mesh_1d, PartitionSpec("data"))

        # Wrap the Flax init so the initialized parameters come out sharded.
        orig_init_fn = nnx_emb.init

        def sharded_init(prng, *sample_args, **sample_kwargs):
            return jax.jit(orig_init_fn, out_shardings=sharding)(
                prng, *sample_args, **sample_kwargs
            )

        nnx_emb.init = sharded_init  # type: ignore

        emb_module = FlaxNNModule(env, nnx_emb, (sample_input,))
        emb_module.to('jax')  # raises AttributeError (see traceback below)
```
Steps to reproduce the behavior:
- Run the code above to get the following exception:
```
File ".../xla/torchax/torchax/tensor.py", line 445, in _to_copy
    current_platform = the_tensor.jax_device.platform
AttributeError: 'NamedSharding' object has no attribute 'platform'
```
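One possible direction for a fix, shown as a hedged sketch rather than torchax's actual code (the `resolve_platform` helper below is hypothetical): when `jax_device` holds a `Sharding`, derive the platform from the sharding's device set instead of reading `.platform` directly.

```python
from jax.sharding import Sharding

def resolve_platform(jax_device):
    """Hypothetical helper: return a platform string whether jax_device
    is a concrete jax.Device or a Sharding such as NamedSharding."""
    if isinstance(jax_device, Sharding):
        # All devices spanned by a sharding share one platform,
        # so any member of device_set works.
        return next(iter(jax_device.device_set)).platform
    return jax_device.platform
```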
Expected behavior
The module should be moved to the jax device successfully.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: 2.8.0.dev