..
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine added support for FP8 attention in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Here, we take the `MultiheadAttention` module as an example. In Transformer Engine 1.11, its FP8 attention metadata is stored under the `core_attention._extra_state` key, as shown below.

.. code-block:: python

    >>> import torch
    >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
    >>> with fp8_model_init(enabled=True):
    ...     mha = MultiheadAttention(
    ...         hidden_size=1024,
    ...         num_attention_heads=16,
    ...         bias=True,
    ...         params_dtype=torch.bfloat16,
    ...         input_layernorm=False,
    ...         fuse_qkv_params=True,
    ...         attention_type="self",
    ...         qkv_weight_interleaved=True,
    ...     ).to(dtype=torch.bfloat16, device="cuda")
    ...
    >>> state_dict = mha.state_dict()
    >>> print(state_dict.keys())
    odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])

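For reference, the sketch below saves this state dict to disk and restores it with the standard PyTorch checkpoint utilities. It continues the session above (it reuses the `mha` module and the `torch` import), and the file name is a placeholder. Depending on your PyTorch and Transformer Engine versions, `torch.load` may need `weights_only=False` to deserialize the `._extra_state` entries.

.. code-block:: python

    >>> # "ckpt.pt" is a placeholder path.
    >>> torch.save(mha.state_dict(), "ckpt.pt")
    >>> # weights_only=False may be required to deserialize the serialized FP8 metadata.
    >>> mha.load_state_dict(torch.load("ckpt.pt", weights_only=False))
    <All keys matched successfully>
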
Below is the full list of checkpoint save/load behaviors across Transformer Engine versions.

.. list-table::

   * - **Version: <= 1.5**

       - Saves no FP8 metadata since FP8 attention is not supported
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Loads no FP8 metadata
         :> 1.5: Error: unexpected key

   * - **Version: 1.6, 1.7**

       - Saves FP8 metadata to `core_attention.fused_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
         :1.6, 1.7: Loads FP8 metadata from checkpoint
         :>= 1.8: Error: unexpected key

   * - **Version: >= 1.8, <= 1.11**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
         :1.6, 1.7: Requires users to map the 1.6/1.7 key to the 1.8-1.11 key; otherwise, initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes. In this `MultiheadAttention` example, the mapping can be done with the snippet below (a more general sketch for full models follows this table):

           .. code-block:: python

               >>> state_dict["core_attention._extra_state"] = \
               ...     state_dict["core_attention.fused_attention._extra_state"]
               >>> del state_dict["core_attention.fused_attention._extra_state"]

         :>= 1.8: Loads FP8 metadata from checkpoint

   * - **Version: >= 1.12**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
         :>= 1.6: Loads FP8 metadata from checkpoint
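
For a full model that contains many attention layers, remapping each `core_attention.fused_attention._extra_state` key by hand is tedious. The following is a minimal sketch, not part of the Transformer Engine API, of applying the same renaming to every matching key in a checkpoint before loading. The helper name, the checkpoint path, and `model` are placeholders, and the `import torch` from the earlier example is assumed.

.. code-block:: python

    >>> def remap_fp8_attention_keys(state_dict):
    ...     """Rename 1.6/1.7-style FP8 attention metadata keys to the >= 1.8 layout."""
    ...     old_suffix = "core_attention.fused_attention._extra_state"
    ...     new_suffix = "core_attention._extra_state"
    ...     for key in [k for k in state_dict if k.endswith(old_suffix)]:
    ...         # Keep any module prefix (e.g. "layers.0.self_attention.") and swap only the suffix.
    ...         state_dict[key[:-len(old_suffix)] + new_suffix] = state_dict.pop(key)
    ...     return state_dict
    ...
    >>> # "old_ckpt.pt" and "model" are placeholders for your checkpoint and module.
    >>> state_dict = remap_fp8_attention_keys(torch.load("old_ckpt.pt", weights_only=False))
    >>> model.load_state_dict(state_dict)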