Commit c81733f

cyanguwa authored and ksivaman committed
[PyTorch] Miscellaneous fixes for FP8 DPA module (#804)
* initialize tp_group for FP8 DPA
* fix cuDNN version in unit tests for cuDNN v9
* add hook to ignore missing fused_attn._extra_states if training from old checkpoints
* remove test and redundant implementation from last commit
* remove warning message and replace with docstring
* remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group
* move core_attention.fused_attention._extra_state to core_attention._extra_state
* simplify post_state_dict_hooks between FU and DPA
* add temporary test
* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test
* remove the test
* disable pylint self arg for hook which is required by hook

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
1 parent 3c604eb commit c81733f

File tree

2 files changed: +14 -1 lines changed


tests/pytorch/fused_attn/test_fused_attn.py

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@ def reset_global_fp8_state():
 def _cudnn_version() -> Tuple[int, int, int]:
     """Runtime cuDNN version (major, minor, patch)"""
     encoded_version = ext.get_cudnn_version()
-    major, encoded_version = divmod(encoded_version, 1000)
+    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
+    major, encoded_version = divmod(encoded_version, major_version_magnitude)
     minor, patch = divmod(encoded_version, 100)
     return (major, minor, patch)
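For context on the change above: cuDNN 8.x encodes its runtime version as major*1000 + minor*100 + patch (e.g. 8.9.7 -> 8907), while cuDNN 9.x switched to major*10000 + minor*100 + patch (e.g. 9.1.0 -> 90100), so the divisor for the major version is now chosen against a 90000 threshold. A minimal standalone sketch of the same decoding, with hard-coded sample values in place of ext.get_cudnn_version() (the helper name below is illustrative, not part of the commit):

from typing import Tuple

def decode_cudnn_version(encoded_version: int) -> Tuple[int, int, int]:
    """Decode a cuDNN runtime version integer into (major, minor, patch)."""
    # cuDNN < 9 packs versions as major*1000 + minor*100 + patch;
    # cuDNN >= 9 uses major*10000 + minor*100 + patch.
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)

assert decode_cudnn_version(8907) == (8, 9, 7)   # cuDNN 8.9.7
assert decode_cudnn_version(90100) == (9, 1, 0)  # cuDNN 9.1.0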

transformer_engine/pytorch/attention.py

Lines changed: 12 additions & 0 deletions
@@ -2711,6 +2711,17 @@ def __init__(
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
 
+        def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
+            """
+            Temporarily remove fused_attention._extra_state as a missing key
+            when loading older TransformerEngine checkpoints. Will phase out
+            this hook in TransformerEngine 2.0.
+            """
+            for key in incompatible_keys.missing_keys:
+                if 'fused_attention._extra_state' in key:
+                    incompatible_keys.missing_keys.remove(key)
+        self.register_load_state_dict_post_hook(remove_extra_states_check)
+
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],

@@ -3063,6 +3074,7 @@ def __init__(
             layer_number=layer_number,
             deterministic=self.deterministic,
             **attn_kwargs)
+
         self.unfused_attention = UnfusedDotProductAttention(
             norm_factor, **attn_kwargs, layer_number=layer_number)
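The hook registered above uses torch.nn.Module.register_load_state_dict_post_hook, whose hooks receive (module, incompatible_keys) and may mutate incompatible_keys.missing_keys in place before load_state_dict applies its strict check. A minimal self-contained illustration of the same pattern outside TransformerEngine (the toy module and buffer name are made up for this example, not part of the commit):

import torch

class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        # Extra state that older checkpoints do not contain.
        self.register_buffer("fused_attention_extra_state", torch.zeros(1))

        def ignore_extra_state(module, incompatible_keys):
            # Drop the known-missing key so strict loading of old checkpoints succeeds.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention_extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(ignore_extra_state)

# An "old" checkpoint without the extra state loads cleanly under strict=True.
old_checkpoint = {"proj.weight": torch.zeros(4, 4), "proj.bias": torch.zeros(4)}
ToyAttention().load_state_dict(old_checkpoint, strict=True)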
