Device-to-host transfers are very slow and asymmetrical #8525

Open
neel04 opened this issue Jan 3, 2025 · 1 comment
Comments


neel04 commented Jan 3, 2025

🐛 Bug

Moving data from device to host is really slow: at least 7-10x slower than JAX.

For many workloads (e.g. inference) this latency is critical, and such subpar performance makes torch_xla simply infeasible for production deployments.

To Reproduce

Steps to reproduce the behavior:

  1. Go to the linked Colab.
  2. Run the JAX or the PyTorch/XLA benchmark (a minimal sketch of the measurement loop is shown below).
  3. Click Runtime -> Disconnect and delete runtime between runs to ensure no interference between the frameworks.
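
For reference, a minimal sketch of the PyTorch/XLA measurement loop (the Colab has the full version; sizes, iteration counts, and helper names here are illustrative, not the exact notebook code):

```python
import time

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

def measure(size_mb, iters=20):
    n = size_mb * 1024 * 1024 // 4          # number of float32 elements
    host = torch.randn(n)

    # Host-to-device: copy to the XLA device and force the transfer to complete.
    start = time.perf_counter()
    for _ in range(iters):
        dev = host.to(device)
        xm.mark_step()
        xm.wait_device_ops()
    h2d = iters * host.nbytes / (time.perf_counter() - start) / 1e9

    # Device-to-host: .cpu() blocks until the data has been fetched.
    start = time.perf_counter()
    for _ in range(iters):
        _ = dev.cpu()
    d2h = iters * host.nbytes / (time.perf_counter() - start) / 1e9
    return h2d, d2h

for size in (1, 5, 10):
    h2d, d2h = measure(size)
    print(f"{size} MB: H2D {h2d:.2f} GB/s, D2H {d2h:.2f} GB/s")
```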

Expected behavior

This is the performance offered by JAX:

TPU Transfer Bandwidth Benchmark - JAX
================================================================================
JAX version: 0.4.33
================================================================================

Testing array size: 1 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
            JAX    No         1.06         1.08
            JAX   Yes         3.98         1.89

Testing array size: 5 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
            JAX    No         2.55         3.07
            JAX   Yes        10.62         7.55

Testing array size: 10 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
            JAX    No         3.14         1.81
            JAX   Yes        16.57        11.97

Whereas Torch XLA:

TPU Transfer Bandwidth Benchmark - PyTorch
================================================================================
PyTorch XLA version: 2.5.1+libtpu
================================================================================

Testing array size: 1 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
        PyTorch    No         4.14         0.66
        PyTorch   Yes         2.54         0.66

Testing array size: 5 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
        PyTorch    No        15.43         1.43
        PyTorch   Yes        10.85         1.18

Testing array size: 10 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
        PyTorch    No        12.42         1.35
        PyTorch   Yes         9.43         1.33

Clearly, the device-to-host bandwidth lags JAX's, by roughly 10x in the worst case and 3x in the best.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU v2-8, v3-8, v5e-1
  • torch_xla version: 2.5.1+libtpu

Additional context

Metrics analysis backs up this asymmetry:

Metric: TransferToDeviceTime
  TotalSamples: 120
  Accumulator: 005ms294.417us
  ValueRate: 005ms525.322us / second
  Rate: 102.568 / second
  Percentiles: 1%=034.743us; 5%=035.169us; 10%=036.223us; 20%=037.516us; 50%=040.789us; 80%=050.877us; 90%=056.871us; 95%=060.393us; 99%=068.632us

Metric: TransferFromDeviceTime
  TotalSamples: 60
  Accumulator: 138ms583.965us
  ValueRate: 123ms072.651us / second
  Rate: 53.6717 / second
  Percentiles: 1%=709.768us; 5%=723.600us; 10%=736.108us; 20%=754.730us; 50%=002ms148.061us; 80%=004ms010.750us; 90%=004ms035.293us; 95%=004ms112.610us; 99%=004ms169.050us

These are anomalous numbers considering the (relatively) small tensor sizes.
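
The counters above come from the standard torch_xla metrics report; a minimal sketch of how they were dumped after the transfer loop:

```python
# Dump torch_xla's built-in counters after running the benchmark loop.
import torch_xla.debug.metrics as met

# The report includes TransferToDeviceTime and TransferFromDeviceTime.
print(met.metrics_report())
```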

Additionally, I can confirm that this issue also arises with production-grade models, where such latencies cripple performance.

I would also be curious why there is such an asymmetry between H2D and D2H performance. I know D2H is blocking, but is this an XLA bottleneck where it is unable to efficiently stream and overlap transfers with computation, while just happening to be better optimized for H2D transfers?

Happy to provide more details upon request.

qihqi (Collaborator) commented Jan 3, 2025

Hi @neel04

(I don't really have an answer for your question.)

Given that JAX's transfer speed satisfies your needs, feel free to explore torch_xla2: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/examples/eager_mode.py

The tensor in torch_xla2 wraps an underlying jax.Array, so it should have the same performance characteristics as JAX.
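
Since a torch_xla2 tensor is backed by a jax.Array, device transfers go through JAX's own path; a minimal sketch of timing that path directly in plain JAX (no torch_xla2-specific APIs assumed):

```python
# Time the raw JAX transfer path that a torch_xla2 tensor would inherit.
# Plain JAX only; no torch_xla2-specific APIs are used here.
import time

import numpy as np
import jax

host = np.random.rand(10 * 1024 * 1024 // 8)   # ~10 MB of float64

start = time.perf_counter()
dev = jax.device_put(host)
dev.block_until_ready()                        # host-to-device
h2d = host.nbytes / (time.perf_counter() - start) / 1e9

start = time.perf_counter()
back = jax.device_get(dev)                     # device-to-host (blocking)
d2h = host.nbytes / (time.perf_counter() - start) / 1e9

print(f"H2D {h2d:.2f} GB/s, D2H {d2h:.2f} GB/s")
```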
