PyTorch/XLA 2.2 Release Notes
Cloud TPUs now support the PyTorch 2.2 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.2 release, this release introduces several features and PyTorch/XLA-specific bug fixes.
Installing PyTorch and PyTorch/XLA 2.2.0 wheel:
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
Please note that you might have to reinstall libtpu on your TPU VM depending on your previous installation:
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
- Note: If you encounter the error RuntimeError: operator torchvision::nms does not exist when using torchvision in the 2.2.0 docker image, run the following command to fix the issue:
pip uninstall torch -y; pip install torch==2.2.0
Stable Features
PJRT
- PJRT_DEVICE=GPU has been renamed to PJRT_DEVICE=CUDA (#5754); PJRT_DEVICE=GPU will be removed in the 2.3 release (see the sketch after this list).
- Optimize host-to-device (#5772) and device-to-host (#5825) transfers.
- Miscellaneous low-level refactoring and performance improvements (#5799, #5737, #5794, #5793, #5546).
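For reference, a minimal sketch of targeting the renamed CUDA PJRT device from Python; the tensor created here is purely illustrative:

import os
os.environ["PJRT_DEVICE"] = "CUDA"  # previously PJRT_DEVICE=GPU, removed in 2.3

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()              # resolves to the CUDA PJRT device
x = torch.randn(2, 2, device=device)  # runs on the XLA:CUDA backend
print(x.device)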
Beta Features
GSPMD
- Support DTensor API integration and move GSPMD out of experimental (#5776).
- Enable the debug visualization function visualize_tensor_sharding (#5742) and add documentation (see the sketch after this list).
- Support mark_shard on scalar tensors (#6158).
- Add apply_backward_optimization_barrier (#6157).
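A minimal SPMD sharding sketch; the mesh shape and the debugging module path (torch_xla.distributed.spmd.debugging) are assumptions here, so consult the GSPMD user guide for the exact imports:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable the SPMD execution mode

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(16, 128, device=xm.xla_device())
xs.mark_sharding(t, mesh, ('data', 'model'))  # shard dim 0 over the 'data' axis

# Assumed location of the visualization helper introduced in #5742.
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
visualize_tensor_sharding(t, use_color=False)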
Export
- Handle lifted constants in torch export (#6111).
- Run decompositions before processing (#5713).
- Support export to tf.saved_model for models with unused params (#5694).
- Add an option to not save the weights (#5964).
- Experimental support for dynamic dimension sizes in torch export to StableHLO (#5790, openxla/xla#6897).
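As an illustration, a hedged sketch of the torch.export to StableHLO path; the tiny module below is hypothetical, and the saved_model helper mentioned in the closing comment requires TensorFlow to be installed:

import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x + 1)

model = TinyModel().eval()
sample_args = (torch.randn(4, 8),)

exported = export(model, sample_args)                  # torch.export ExportedProgram
stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward'))

# For tf.saved_model, torch_xla.tf_saved_model_integration exposes
# save_torch_module_as_tf_saved_model (see #5694 for unused-param handling).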
CoreAtenOpSet
- PyTorch/XLA aims to support all PyTorch core ATen ops in the 2.3 release. We’re actively working on this; the remaining issues to be closed can be found in the issue list.
Benchmark
- Support benchmark running automation and metric report analysis on both TPU and GPU (doc).
Experimental Features
FSDP via SPMD
- Introduce FSDP via SPMD, or FSDPv2 (#6187). The RFC can be found at #6379 (see the sketch after this list).
- Add FSDPv2 user guide (#6386).
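A minimal sketch of wrapping a module with FSDPv2; the import path and the mesh axis names are assumptions based on the user guide rather than a definitive recipe:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
# Assumed module path for the experimental FSDPv2 wrapper.
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# FSDPv2 shards parameters along the mesh axis named 'fsdp'.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'model'))

model = torch.nn.Linear(128, 128).to(xm.xla_device())
model = FSDPv2(model, mesh)

x = torch.randn(16, 128, device=xm.xla_device())
loss = model(x).sum()
loss.backward()
xm.mark_step()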
Distributed Op
Persistent Compilation
- Enable persistent compilation caching (#6065).
- Document and introduce the xr.initialize_cache Python API (#6046).
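A minimal sketch of enabling the on-disk compilation cache; the cache path is illustrative, and initialize_cache must run before the first XLA computation:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# '/tmp/xla_compile_cache' is an illustrative path.
xr.initialize_cache('/tmp/xla_compile_cache', readonly=False)

device = xm.xla_device()
x = torch.randn(8, 8, device=device)
y = (x @ x).sum()
xm.mark_step()  # the compiled graph is persisted and reused on later runs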
Checkpointing
- Support auto checkpointing for TPU preemption (#5753).
- Support Async checkpointing through CheckpointManager (#5697).
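A hedged sketch of asynchronous checkpointing through CheckpointManager; the save path, interval, and training loop are illustrative, and the exact constructor arguments may differ:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

model = torch.nn.Linear(16, 16).to(xm.xla_device())
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# 'gs://my-bucket/checkpoints' and save_interval=10 are illustrative values.
chkpt_mgr = CheckpointManager(path='gs://my-bucket/checkpoints', save_interval=10)

for step in range(100):
    # ... forward/backward/optimizer step elided ...
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.should_save(step):
        # save_async returns immediately; the write happens in the background.
        chkpt_mgr.save_async(step, state_dict)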
Usability
Quantization
- Lower quant/dequant torch op to StableHLO (#5763).
GPU
Bug Fixes and Improvements
- Fix pow precision issue (#6103).
- Handle negative dim for Diagonal Scatter (#6123).
- Fix as_strided for inputs smaller than the arguments specification (#5914).
- Fix squeeze op lowering issue when dim is not in sorted order (#5751).
- Optimize RNG seed dtype for better memory utilization (#5710).
- Lower _prelu_kernel_backward (#5724).