diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 40cfdd34b7..caba385d46 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -70,7 +70,8 @@ def reset_global_fp8_state():
 def _cudnn_version() -> Tuple[int, int, int]:
     """Runtime cuDNN version (major, minor, patch)"""
     encoded_version = ext.get_cudnn_version()
-    major, encoded_version = divmod(encoded_version, 1000)
+    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
+    major, encoded_version = divmod(encoded_version, major_version_magnitude)
     minor, patch = divmod(encoded_version, 100)
     return (major, minor, patch)
 
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3bf4598fc1..af6c151cab 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -2711,6 +2711,17 @@ def __init__(
             if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                 os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
 
+        def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
+            """
+            Temporarily remove fused_attention._extra_state as a missing key
+            when loading older TransformerEngine checkpoints. Will phase out
+            this hook in TransformerEngine 2.0.
+            """
+            for key in incompatible_keys.missing_keys:
+                if 'fused_attention._extra_state' in key:
+                    incompatible_keys.missing_keys.remove(key)
+        self.register_load_state_dict_post_hook(remove_extra_states_check)
+
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
@@ -3063,6 +3074,7 @@ def __init__(
                 layer_number=layer_number,
                 deterministic=self.deterministic,
                 **attn_kwargs)
+
         self.unfused_attention = UnfusedDotProductAttention(
             norm_factor, **attn_kwargs, layer_number=layer_number)
 
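
Context for the first hunk: cuDNN changed how it packs its integer version starting with cuDNN 9 (roughly `major*1000 + minor*100 + patch` before, `major*10000 + minor*100 + patch` from 9.0 on), so a fixed divisor of 1000 misparses 9.x versions. Below is a minimal standalone sketch of the updated decoding; `decode_cudnn_version` and the sample encoded values are illustrative and not part of the patch.

```python
from typing import Tuple

def decode_cudnn_version(encoded_version: int) -> Tuple[int, int, int]:
    """Split cuDNN's single-integer version into (major, minor, patch).

    Mirrors the patched test helper: encoded values >= 90000 follow the
    cuDNN 9 scheme (major * 10000); older values use major * 1000.
    """
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)

# Old encoding: 8.9.7 -> 8907; new encoding: 9.1.0 -> 90100.
assert decode_cudnn_version(8907) == (8, 9, 7)
assert decode_cudnn_version(90100) == (9, 1, 0)
```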
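
The second hunk relies on PyTorch's `Module.register_load_state_dict_post_hook`: the hook receives `(module, incompatible_keys)` and may mutate `missing_keys` in place before `load_state_dict(strict=True)` decides whether to raise. A rough standalone sketch of that mechanism is below; `TinyLayer`, `drop_legacy_missing_keys`, and the buffer name are made up for illustration and are not TransformerEngine code.

```python
import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    """Toy layer whose newer versions serialize an extra buffer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Older checkpoints predate this buffer, so strict loading would
        # normally report it as a missing key.
        self.register_buffer("fused_attention_extra_state", torch.zeros(1))

def drop_legacy_missing_keys(module, incompatible_keys):
    """Forgive checkpoints that lack the newly added buffer."""
    for key in list(incompatible_keys.missing_keys):
        if "fused_attention_extra_state" in key:
            incompatible_keys.missing_keys.remove(key)

layer = TinyLayer()
layer.register_load_state_dict_post_hook(drop_legacy_missing_keys)

# Simulate an "old" checkpoint: identical weights, minus the new buffer.
old_checkpoint = {k: v for k, v in layer.state_dict().items()
                  if "fused_attention_extra_state" not in k}

# Without the hook this would raise a missing-key RuntimeError under strict=True.
layer.load_state_dict(old_checkpoint, strict=True)
```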