
PyTorch/XLA 2.2 Release Notes

Cloud TPUs now support the PyTorch 2.2 release via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in PyTorch 2.2, this release introduces several new features and PyTorch/XLA-specific bug fixes.

To install the PyTorch and PyTorch/XLA 2.2.0 wheels:

pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html

Note that you might have to re-install libtpu on your TPU VM, depending on your previous installation:

pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
  • Note: If you encounter the error RuntimeError: operator torchvision::nms does not exist when using torchvision in the 2.2.0 Docker image, run the following command to fix the issue:
pip uninstall torch -y; pip install torch==2.2.0

Stable Features

PJRT

  • PJRT_DEVICE=GPU has been renamed to PJRT_DEVICE=CUDA (#5754); see the sketch after this list.
    • PJRT_DEVICE=GPU will be removed in the 2.3 release.
  • Optimize host-to-device transfer (#5772) and device-to-host transfer (#5825).
  • Miscellaneous low-level refactoring and performance improvements (#5799, #5737, #5794, #5793, #5546).
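
A minimal sketch of selecting the CUDA PJRT backend via the renamed environment variable (the device-selection call itself is unchanged from earlier releases):

import os

os.environ["PJRT_DEVICE"] = "CUDA"  # formerly "GPU"; the old value is removed in 2.3

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()            # resolves to a CUDA-backed XLA device
x = torch.ones(2, 2, device=device)
print(x.device)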

Beta Features

GSPMD

  • Support DTensor API integration and move GSPMD out of experimental (#5776).
  • Enable the debug visualization function visualize_tensor_sharding (#5742); documentation added. See the sketch after this list.
  • Support mark_sharding for scalar tensors (#6158).
  • Add apply_backward_optimization_barrier (#6157).
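
A minimal GSPMD sketch combining the items above. The import paths (torch_xla.distributed.spmd and its debugging submodule) reflect the 2.2 layout and should be treated as assumptions for illustration:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

t = torch.randn(8, 128, device=xm.xla_device())
xs.mark_sharding(t, mesh, ("data", "model"))   # shard dim 0 across the "data" axis
visualize_tensor_sharding(t, use_color=False)  # print the resulting sharding layout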

Export

  • Handle lifted constants in torch export (#6111).
  • Run decompositions before processing (#5713).
  • Support export to tf.saved_model for models with unused params (#5694).
  • Add an option to not save the weights (#5964).
  • Experimental support for dynamic dimension sizes in torch export to StableHLO (#5790, openxla/xla#6897). A basic export-to-StableHLO sketch follows this list.
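
A minimal sketch of the torch.export to StableHLO path that the items above build on; exported_program_to_stablehlo and get_stablehlo_text are the 2.2-era names and should be treated as assumptions here:

import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x + 1.0)

model = TinyModel().eval()
sample_args = (torch.randn(4, 8),)

exported = export(model, sample_args)            # torch.export ExportedProgram
shlo = exported_program_to_stablehlo(exported)   # lower the exported graph to StableHLO
print(shlo.get_stablehlo_text("forward"))        # inspect the StableHLO module text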

CoreAtenOpSet

  • PyTorch/XLA aims to support all PyTorch core ATen ops in the 2.3 release. We are actively working on this; the remaining issues to be closed can be found in the issue list.

Benchmark

  • Support benchmark-running automation and metric report analysis on both TPU and GPU (doc).

Experimental Features

FSDP via SPMD

  • Introduce FSDP via SPMD, or FSDPv2 (#6187). The RFC can be found in #6379. See the sketch after this list.
  • Add FSDPv2 user guide (#6386).
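
A minimal FSDPv2 sketch based on the user guide. The module path, class name, and the "fsdp" mesh-axis convention are assumptions for illustration:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "model"))

model = torch.nn.Linear(128, 128).to(xm.xla_device())
model = FSDPv2(model, mesh=mesh)  # parameters sharded along the "fsdp" mesh axis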

Distributed Op

  • Support all-gather coalescing (#5950); see the sketch after this list.
  • Support reduce-scatter coalescing (#5956).
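
A hedged sketch of the coalesced path; the assumption here is that passing a list of tensors to xm.all_gather issues a single coalesced collective instead of one op per tensor:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
grads = [torch.randn(16, device=device), torch.randn(32, device=device)]

# One coalesced all-gather over both tensors (list input is the assumption).
gathered = xm.all_gather(grads, dim=0)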

Persistent Compilation

  • Enable persistent compilation caching (#6065).
  • Document and introduce the xr.initialize_cache Python API (#6046); see the sketch after this list.
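
A minimal sketch of enabling the persistent compilation cache with xr.initialize_cache, which must be called before any graph is compiled:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.initialize_cache("/tmp/xla_compile_cache", readonly=False)  # call once per process

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = (x @ x).sum()
xm.mark_step()  # the compiled program is persisted to the cache directory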

Checkpointing

  • Support auto checkpointing for TPU preemption (#5753).
  • Support async checkpointing through CheckpointManager (#5697); see the sketch after this list.
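
A hedged sketch of async checkpointing through CheckpointManager; the module path and the should_save/save_async method names follow the 2.2 experimental API and should be treated as assumptions:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

model = torch.nn.Linear(8, 8).to(xm.xla_device())
ckpt_mgr = CheckpointManager("/tmp/xla_checkpoints", save_interval=10)

step = 10
state_dict = {"model": model.state_dict()}
if ckpt_mgr.should_save(step):
    ckpt_mgr.save_async(step, state_dict)  # checkpoint is written in the background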

Usability

  • Document Compilation/Execution analysis (#6039).
  • Add a profiler API for async capture (#5969); see the sketch after this list.
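
A hedged sketch of the async capture API; trace_detached and its arguments are assumptions modeled on the blocking xp.trace interface:

import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # serve profiling data from this process

# Start a capture that returns immediately instead of blocking (async capture).
xp.trace_detached("localhost:9012", "/tmp/profiles", duration_ms=2000)

# ... run the training steps to be profiled while the capture is active ...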

Quantization

  • Lower quant/dequant torch op to StableHLO (#5763).

GPU

  • Document multi-host GPU training (#5704).
  • Support multi-node training via torchrun (#5657).

Bug Fixes and Improvements

  • Fix pow precision issue (#6103).
  • Handle negative dim for diagonal_scatter (#6123).
  • Fix as_strided for inputs smaller than the argument specification (#5914).
  • Fix squeeze op lowering issue when dim is not in sorted order (#5751).
  • Optimize RNG seed dtype for better memory utilization (#5710).

Lowering

  • _prelu_kernel_backward (#5724).