v1.3

@ptrendx ptrendx released this 26 Feb 22:11
· 410 commits to main since this release

Release Notes – Release 1.3

Key Features and Enhancements

  • [pyTorch] Added support for deferred parameter initialization in several Transformer Engine modules via the device="meta" parameter:
    Linear
    LayerNorm
    RMSNorm
    LayerNormLinear
    LayerNormMLP
    MultiheadAttention
    TransformerLayer
  • [pyTorch] Added support for CPU offloading of weight and activation tensors saved for the backward pass, which provides additional memory savings.
  • [pyTorch] Added an attn_input_format parameter to TransformerLayer that specifies the layout of the QKV tensor.
  • [pyTorch] Added support for non-tensor values of the forward parameter when using the checkpoint API call.
  • [PaddlePaddle] Added support for sequence parallelism.
  • [PaddlePaddle] Optimized memory usage for pipeline parallel training.
  • [JAX] Added support for grouped query attention (GQA).
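
The deferred initialization item above relies on PyTorch's meta device. The following is a minimal sketch of that general pattern, assuming plain torch.nn.Linear as a stand-in for a Transformer Engine module; how Transformer Engine itself materializes parameters later is not described in these notes, so the to_empty and reset_parameters steps shown here are the generic PyTorch approach, not a statement about the library's internals.

```python
import torch

# Stand-in for a Transformer Engine module (assumption): torch.nn.Linear
# also accepts device="meta", which records shapes without allocating storage.
layer = torch.nn.Linear(8, 4, device="meta")

# Meta parameters carry shape and dtype only.
was_meta = layer.weight.is_meta

# Later, allocate uninitialized storage on a real device...
layer = layer.to_empty(device="cpu")

# ...and then run the module's own initializer to fill in values.
layer.reset_parameters()
```

Constructing on the meta device keeps model creation cheap, so memory is only committed once the target device (or shard) for each parameter is known.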

Fixed Issues

  • [pyTorch] In LayerNormLinear and Linear, unused copies of weight and bias tensors were not deleted when the Q, K, and V tensors are fused.
  • [pyTorch] Faulty usage of pipeline parallelism with the FusedAttention backend.
  • [pyTorch] attention_type was not correctly passed from the MultiheadAttention call to the DotProductAttention call.
  • [pyTorch] Fused DPA backend reported bogus NaN errors during the backward pass.
  • [pyTorch] Crashes when running with PyTorch v2.0.1.
  • [pyTorch] Statistics could be computed incorrectly when training with FP8 in recent versions of pyTorch. For details see #600.
  • [JAX] Crashes when training in FP8 + FSDP.

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue by setting the environment variable MAX_JOBS=1 during Transformer Engine installation.
  • [pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep Transformer Engine's behavior consistent between versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention version 2.1 or later is installed.
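
The MAX_JOBS workaround from the first known issue can also be applied from a script. This is a minimal sketch using only the Python standard library; the helper names build_install_env and install are hypothetical, and the package specification is left as a placeholder because these notes do not give an installation command.

```python
import os
import subprocess
import sys


def build_install_env(max_jobs=1):
    """Return a copy of the current environment with MAX_JOBS set."""
    env = dict(os.environ)
    env["MAX_JOBS"] = str(max_jobs)  # limits parallel build jobs
    return env


def install(pip_args, max_jobs=1):
    """Run `pip install <pip_args>` with MAX_JOBS limited (hypothetical helper)."""
    cmd = [sys.executable, "-m", "pip", "install", *pip_args]
    return subprocess.run(cmd, env=build_install_env(max_jobs), check=True)


# Example call (not executed here); the argument is a placeholder:
# install(["<Transformer Engine package or source path>"])
```

Setting the variable only in the child process environment avoids changing MAX_JOBS for unrelated builds in the same shell.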

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

Miscellaneous Changes

FlashAttention v1 is no longer supported in Transformer Engine. The minimum required version is v2.0.6.
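
The two FlashAttention version rules in these notes (minimum version v2.0.6, and no FlashAttention for cross-attention with causal masking on v2.1 or later) can be illustrated as a small check. This is a sketch only: the function names and the "self"/"cross" strings are illustrative assumptions, and Transformer Engine's real backend selection considers more conditions than the two stated here.

```python
import re

MIN_FLASH_ATTN = (2, 0, 6)        # minimum version stated in these notes
CAUSAL_MASK_CHANGE = (2, 1, 0)    # version that changed causal cross-attention


def parse_version(text):
    """Parse "2.0.6", "v2.1", or "2.0.6.post1" into a (major, minor, patch) tuple."""
    nums = [int(n) for n in re.findall(r"\d+", text.split("+")[0])[:3]]
    return tuple(nums + [0] * (3 - len(nums)))


def is_supported(version):
    """True if the FlashAttention version meets the stated minimum."""
    return parse_version(version) >= MIN_FLASH_ATTN


def flash_attention_allowed(version, attention_type, causal):
    """Apply only the two rules from the release notes (hypothetical helper)."""
    parsed = parse_version(version)
    if parsed < MIN_FLASH_ATTN:
        return False
    if attention_type == "cross" and causal and parsed >= CAUSAL_MASK_CHANGE:
        return False
    return True
```

For example, flash_attention_allowed("2.1.0", "cross", True) is False under these rules, while the same call with "2.0.6" is True.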