Description
🐛 Bug
Moving data from device to host is really slow, at least 7-10x slower than JAX.
For a lot of workloads (e.g. inference) this latency is crucial, and such subpar performance makes torch_xla
simply an infeasible option for production deployments.
To Reproduce
Steps to reproduce the behavior:
- Go to the linked Colab.
- Run the JAX or PyTorch XLA benchmark (a minimal sketch of the measurement is included below).
- Between runs, click Runtime -> Disconnect and delete runtime to ensure no interference between frameworks.
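For reference, here is a minimal sketch of the kind of D2H measurement the Colab performs on the PyTorch/XLA side (not the exact Colab code; the helper name, sizes, and iteration count are arbitrary):

    import time
    import torch
    import torch_xla.core.xla_model as xm

    def d2h_bandwidth_gbps(size_mb: int, iters: int = 20) -> float:
        """Rough device-to-host bandwidth for a float32 tensor of size_mb MB."""
        device = xm.xla_device()
        n = size_mb * 1024 * 1024 // 4        # number of float32 elements
        x = torch.randn(n, device=device)
        xm.mark_step()                        # flush the pending lazy graph
        xm.wait_device_ops()                  # make sure x is materialized on device

        start = time.perf_counter()
        for _ in range(iters):
            _ = x.cpu()                       # blocking device-to-host copy
        elapsed = time.perf_counter() - start

        return (n * 4 * iters) / elapsed / 1e9

    for mb in (1, 5, 10):
        print(f"{mb} MB: {d2h_bandwidth_gbps(mb):.2f} GB/s D2H")

The JAX side does the analogous thing: block_until_ready() on the device array first, then time jax.device_get() in a loop.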
Expected behavior
This is the performance offered by JAX:
TPU Transfer Bandwidth Benchmark - JAX
================================================================================
JAX version: 0.4.33
================================================================================

Testing array size: 1 MB
--------------------------------------------------------------------------------
Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
JAX         No          1.06         1.08
JAX         Yes         3.98         1.89

Testing array size: 5 MB
--------------------------------------------------------------------------------
Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
JAX         No          2.55         3.07
JAX         Yes        10.62         7.55

Testing array size: 10 MB
--------------------------------------------------------------------------------
Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
JAX         No          3.14         1.81
JAX         Yes        16.57        11.97
Whereas Torch XLA:
TPU Transfer Bandwidth Benchmark - PyTorch
================================================================================
PyTorch XLA version: 2.5.1+libtpu
================================================================================

Testing array size: 1 MB
--------------------------------------------------------------------------------
Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
PyTorch     No          4.14         0.66
PyTorch     Yes         2.54         0.66

Testing array size: 5 MB
--------------------------------------------------------------------------------
Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
PyTorch     No         15.43         1.43
PyTorch     Yes        10.85         1.18

Testing array size: 10 MB
--------------------------------------------------------------------------------
Framework   JIT   H2D (GB/s)   D2H (GB/s)
--------------------------------------------------------------------------------
PyTorch     No         12.42         1.35
PyTorch     Yes         9.43         1.33
Clearly, the device-to-host bandwidth lags behind JAX, by roughly 10x in the worst case and about 3x in the best.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU v2-8, v3-8, v5e-1
- torch_xla version: 2.5.1+libtpu
Additional context
Metrics analysis backs up this asymmetry:
Metric: TransferToDeviceTime
TotalSamples: 120
Accumulator: 005ms294.417us
ValueRate: 005ms525.322us / second
Rate: 102.568 / second
Percentiles: 1%=034.743us; 5%=035.169us; 10%=036.223us; 20%=037.516us; 50%=040.789us; 80%=050.877us; 90%=056.871us; 95%=060.393us; 99%=068.632us
Metric: TransferFromDeviceTime
TotalSamples: 60
Accumulator: 138ms583.965us
ValueRate: 123ms072.651us / second
Rate: 53.6717 / second
Percentiles: 1%=709.768us; 5%=723.600us; 10%=736.108us; 20%=754.730us; 50%=002ms148.061us; 80%=004ms010.750us; 90%=004ms035.293us; 95%=004ms112.610us; 99%=004ms169.050us
These numbers are anomalous, considering the (relatively) small sizes of the tensors.
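For completeness, this is roughly how the report above can be reproduced after running the transfers (assuming the torch_xla.debug.metrics API; the metric names are the ones printed above):

    import torch_xla.debug.metrics as met

    # Full counters/metrics dump, which includes TransferToDeviceTime and
    # TransferFromDeviceTime once some H2D/D2H traffic has happened.
    print(met.metrics_report())

    # Or pull just the two transfer metrics programmatically.
    for name in ("TransferToDeviceTime", "TransferFromDeviceTime"):
        print(name, met.metric_data(name))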
Additionally, I can confirm that this issue also shows up with production-grade models, where such latencies cripple performance.
I would also be curious why there is such an asymmetry between H2D and D2H performance. I understand that D2H is blocking, but is this an XLA bottleneck where it is unable to efficiently stream and overlap tiles of computation with the transfer, while just happening to be better optimized for H2D transfers?
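In case it helps with triage, here is a hypothetical sketch (not from the Colab) of how I would separate pending computation from the pure copy cost, by syncing the device before timing only the .cpu() call:

    import time
    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    x = torch.randn(5 * 1024 * 1024 // 4, device=device)   # ~5 MB of float32
    y = x * 2 + 1                                           # some pending lazy work

    # Flush and finish all device work first, so the timing below should cover
    # only the device-to-host copy, not compilation or pending computation.
    xm.mark_step()
    xm.wait_device_ops()

    start = time.perf_counter()
    host = y.cpu()                                          # blocking D2H copy
    print(f"D2H only: {(time.perf_counter() - start) * 1e3:.3f} ms")

If the slowdown persists even with this isolation, that would point at the transfer path itself rather than at graph execution being folded into the timing.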
Happy to provide more details upon request.