
Device-to-host transfers are very slow and asymmetrical #8525


Description


🐛 Bug

Moving data from device to host is very slow, at least 7-10x slower than JAX.

For many workloads (e.g. inference) this latency is critical, and such subpar performance makes torch_xla an infeasible option for production deployments.

To Reproduce

Steps to reproduce the behavior:

  1. Go to the linked Colab.
  2. Run the JAX or PyTorch XLA benchmark (a minimal sketch of the D2H measurement is included after these steps).
  3. Click Runtime -> Disconnect and delete runtime to ensure no interference between the frameworks.
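For reference, the PyTorch/XLA side of the D2H measurement looks roughly like the sketch below (not the exact notebook code; tensor sizes, iteration counts, and the helper name `d2h_bandwidth_gbps` are illustrative). A JAX variant would run the analogous loop around `jax.device_put` / `jax.device_get` with `block_until_ready()`.

```python
# Sketch of a D2H bandwidth measurement with torch_xla (illustrative only;
# the helper name and sizes are not from the Colab).
import time

import torch
import torch_xla.core.xla_model as xm


def d2h_bandwidth_gbps(num_mb: int, iters: int = 20) -> float:
    device = xm.xla_device()
    n = num_mb * 1024 * 1024 // 4              # number of float32 elements
    x = torch.randn(n, device=device)
    xm.mark_step()                             # materialize the tensor on device
    xm.wait_device_ops()

    start = time.perf_counter()
    for _ in range(iters):
        _ = x.cpu()                            # device-to-host copy; blocks until done
    elapsed = time.perf_counter() - start

    return (n * 4 * iters) / elapsed / 1e9     # bytes moved -> GB/s


for mb in (1, 5, 10):
    print(f"{mb} MB: {d2h_bandwidth_gbps(mb):.2f} GB/s D2H")
```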

Expected behavior

This is the performance offered by JAX:

TPU Transfer Bandwidth Benchmark - JAX
================================================================================
JAX version: 0.4.33
================================================================================

Testing array size: 1 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
            JAX    No         1.06         1.08
            JAX   Yes         3.98         1.89

Testing array size: 5 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
            JAX    No         2.55         3.07
            JAX   Yes        10.62         7.55

Testing array size: 10 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
            JAX    No         3.14         1.81
            JAX   Yes        16.57        11.97

Whereas Torch XLA:

TPU Transfer Bandwidth Benchmark - PyTorch
================================================================================
PyTorch XLA version: 2.5.1+libtpu
================================================================================

Testing array size: 1 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
        PyTorch    No         4.14         0.66
        PyTorch   Yes         2.54         0.66

Testing array size: 5 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
        PyTorch    No        15.43         1.43
        PyTorch   Yes        10.85         1.18

Testing array size: 10 MB
--------------------------------------------------------------------------------
      Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
        PyTorch    No        12.42         1.35
        PyTorch   Yes         9.43         1.33

Clearly, the device-to-host bandwidth is lacking compared to JAX: roughly 10x lower in the worst case and about 3x in the best.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU v2-8, v3-8, v5e-1
  • torch_xla version: 2.5.1+libtpu

Additional context

Metrics analysis backs up this asymmetry:

Metric: TransferToDeviceTime
  TotalSamples: 120
  Accumulator: 005ms294.417us
  ValueRate: 005ms525.322us / second
  Rate: 102.568 / second
  Percentiles: 1%=034.743us; 5%=035.169us; 10%=036.223us; 20%=037.516us; 50%=040.789us; 80%=050.877us; 90%=056.871us; 95%=060.393us; 99%=068.632us

Metric: TransferFromDeviceTime
  TotalSamples: 60
  Accumulator: 138ms583.965us
  ValueRate: 123ms072.651us / second
  Rate: 53.6717 / second
  Percentiles: 1%=709.768us; 5%=723.600us; 10%=736.108us; 20%=754.730us; 50%=002ms148.061us; 80%=004ms010.750us; 90%=004ms035.293us; 95%=004ms112.610us; 99%=004ms169.050us

These are anomalous numbers considering the (relatively) small sizes of the tensors.
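For reference, these counters come from torch_xla's built-in metrics; a minimal way to dump them (assuming the torch_xla 2.5.x debug API) is:

```python
import torch_xla.debug.metrics as met

# ... run the H2D / D2H benchmark first ...
print(met.metrics_report())                      # full report, including the percentiles above
print(met.metric_data("TransferToDeviceTime"))   # (total_samples, accumulator, recent samples)
print(met.metric_data("TransferFromDeviceTime"))
```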

Additionally, I can confirm that this issue also arises with production-grade models, where such latencies cripple performance.

I would also be curious why there is such an asymmetry between H2D and D2H performance. I know D2H is blocking, but is this an XLA bottleneck where it is unable to efficiently stream and overlap tiles of the transfer, while the H2D path just happens to be better optimized?
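To make the asymmetry concrete, this is roughly how I would time each direction in isolation (a sketch based on my understanding of the lazy-tensor model; sizes are illustrative):

```python
import time

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
host = torch.randn(10 * 1024 * 1024 // 4)    # ~10 MB of float32 on the host

# H2D: time the host-to-device copy, draining pending device work so the
# measurement is not under-counted by lazy execution.
t0 = time.perf_counter()
dev = host.to(device)
xm.mark_step()
xm.wait_device_ops()
print(f"H2D: {time.perf_counter() - t0:.4f} s")

# D2H: .cpu() cannot return until the data is back on the host, so the copy
# is inherently blocking; there is nothing for the caller to overlap it with.
t0 = time.perf_counter()
back = dev.cpu()
print(f"D2H: {time.perf_counter() - t0:.4f} s")
```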

Happy to provide more details upon request.
