
[TORCHAX] Can't allocate random tensors with device='jax' #9411

Open
@vlad-karp


🐛 Bug

Can't allocate random tensors with device='jax'

It fails with the following error message:

jaxlib._jax.XlaRuntimeError: INVALID_ARGUMENT: Unable to replace a PyArray with a PyArray from a different client.

To Reproduce

The reproducer is as simple as:

import torch
import torchax
torchax.enable_globally()
torch.randn(3, 3, 28, 28, device='jax')
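As an extra isolation step (a suggestion, not part of the original report), the sketch below asks JAX directly for a random array of the same shape, with torchax out of the loop. If this succeeds on the same TPU host, the failure is more likely in how torchax hands the generated array back to PyTorch than in JAX's random-number generation itself; that reading is an assumption, not a confirmed diagnosis. The helper name `check_jax_random` is hypothetical.

```python
import jax
import jax.numpy as jnp


def check_jax_random(shape=(3, 3, 28, 28), seed=0):
    """Hypothetical helper: draw a standard-normal array with plain JAX (no torchax)."""
    key = jax.random.PRNGKey(seed)
    x = jax.random.normal(key, shape, dtype=jnp.float32)
    # Block until the computation finishes so any runtime error surfaces here.
    x.block_until_ready()
    return x


x = check_jax_random()
# Report which backend produced the array ("tpu" on the reporter's setup, "cpu" elsewhere).
print(jax.default_backend(), x.shape, x.dtype)
```

If this plain-JAX call also raises `XlaRuntimeError`, the problem is likely below torchax; if it does not, the torchax random-op path is the better place to look.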

Expected behavior

A random tensor should be created without crashing.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 2.8.0.dev

Additional context

Reverting this PR makes the error go away: https://github.com/pytorch/xla/pull/9305/files
