🐛 Bug
Random tensors cannot be created with device='jax'.
The call fails with the following error:
jaxlib._jax.XlaRuntimeError: INVALID_ARGUMENT: Unable to replace a PyArray with a PyArray from a different client.
To Reproduce
Minimal reproduction:
import torch
import torchax
torchax.enable_globally()
torch.randn(3, 3, 28, 28, device='jax')
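As a possible diagnostic step (an assumption, not part of the original report), the same shape can be generated directly with JAX's own RNG in the same environment, without torchax. If this succeeds on TPU, the failure is more likely specific to the torchax device='jax' random path than to JAX itself. This is a minimal sketch; the seed value is arbitrary.

```python
import jax

# Create an explicit PRNG key (seed 0 is arbitrary) and draw a normal sample
# with the same shape as the failing torch.randn call, bypassing torchax.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (3, 3, 28, 28))

# Report the resulting shape and the device(s) the array was placed on.
print(x.shape, x.devices())
```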
Expected behavior
A random tensor of shape (3, 3, 28, 28) should be created on the JAX device without an error.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: 2.8.0.dev
Additional context
Reverting this PR avoids the error: https://github.com/pytorch/xla/pull/9305/files