[PyTorch] Miscellaneous fixes for FP8 DPA module (#804)
* initialize tp_group for FP8 DPA

Signed-off-by: Charlene Yang <[email protected]>

* fix cuDNN version in unit tests for cuDNN v9

Signed-off-by: Charlene Yang <[email protected]>

* add hook to ignore missing fused_attn._extra_states when training from old checkpoints

Signed-off-by: Charlene Yang <[email protected]>

* remove test and redundant implementation from last commit

Signed-off-by: Charlene Yang <[email protected]>

* remove warning message and replace with docstring

Signed-off-by: Charlene Yang <[email protected]>

* remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group

Signed-off-by: Charlene Yang <[email protected]>

* move core_attention.fused_attention._extra_state to core_attention._extra_state

Signed-off-by: Charlene Yang <[email protected]>

* simplify post_state_dict_hooks between FU and DPA

Signed-off-by: Charlene Yang <[email protected]>

* add temporary test

Signed-off-by: Charlene Yang <[email protected]>

* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test

Signed-off-by: Charlene Yang <[email protected]>

* remove the test

Signed-off-by: Charlene Yang <[email protected]>

* disable pylint unused-argument check for the hook's self arg, which the hook signature requires

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
cyanguwa authored and ksivaman committed May 2, 2024
1 parent 3c604eb commit c81733f
Showing 2 changed files with 14 additions and 1 deletion.
tests/pytorch/fused_attn/test_fused_attn.py (3 changes: 2 additions & 1 deletion)
@@ -70,7 +70,8 @@ def reset_global_fp8_state():
 def _cudnn_version() -> Tuple[int, int, int]:
     """Runtime cuDNN version (major, minor, patch)"""
     encoded_version = ext.get_cudnn_version()
-    major, encoded_version = divmod(encoded_version, 1000)
+    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
+    major, encoded_version = divmod(encoded_version, major_version_magnitude)
     minor, patch = divmod(encoded_version, 100)
     return (major, minor, patch)

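For reference, a minimal standalone sketch of the new decoding logic, assuming the CUDNN_VERSION encoding of major*1000 + minor*100 + patch before cuDNN 9 and major*10000 + minor*100 + patch from cuDNN 9 onward; decode_cudnn_version is a hypothetical name for illustration, not a Transformer Engine API:

def decode_cudnn_version(encoded_version: int) -> tuple:
    # cuDNN < 9 packs the major version in thousands (8.9.7 -> 8907);
    # cuDNN 9+ packs it in ten-thousands (9.1.0 -> 90100).
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)

assert decode_cudnn_version(8907) == (8, 9, 7)    # cuDNN 8.9.7
assert decode_cudnn_version(90100) == (9, 1, 0)   # cuDNN 9.1.0
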
transformer_engine/pytorch/attention.py (12 changes: 12 additions & 0 deletions)
@@ -2711,6 +2711,17 @@ def __init__(
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
 
+        def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
+            """
+            Temporarily remove fused_attention._extra_state as a missing key
+            when loading older TransformerEngine checkpoints. Will phase out
+            this hook in TransformerEngine 2.0.
+            """
+            for key in incompatible_keys.missing_keys:
+                if 'fused_attention._extra_state' in key:
+                    incompatible_keys.missing_keys.remove(key)
+        self.register_load_state_dict_post_hook(remove_extra_states_check)
+
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
@@ -3063,6 +3074,7 @@ def __init__(
             layer_number=layer_number,
             deterministic=self.deterministic,
             **attn_kwargs)
+
         self.unfused_attention = UnfusedDotProductAttention(
             norm_factor, **attn_kwargs, layer_number=layer_number)

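The new hook builds on torch.nn.Module.register_load_state_dict_post_hook, which calls the hook with the module and an incompatible_keys object whose missing_keys/unexpected_keys lists can be edited in place before strict checking. Below is a minimal standalone sketch of that pattern, using a made-up ToyAttention module and buffer name rather than Transformer Engine code; it also iterates over a copy of missing_keys so the list is not mutated mid-iteration:

import torch

class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        # Stand-in for an extra-state buffer that older checkpoints do not contain.
        self.register_buffer("fake_extra_state", torch.zeros(1))

        def remove_extra_states_check(module, incompatible_keys):
            # Drop the known-missing key so loading with strict=True does not fail.
            for key in list(incompatible_keys.missing_keys):
                if "fake_extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

model = ToyAttention()
old_checkpoint = {"proj.weight": torch.zeros(4, 4), "proj.bias": torch.zeros(4)}
# Without the hook, the missing "fake_extra_state" key would raise a RuntimeError here.
model.load_state_dict(old_checkpoint, strict=True)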
