Disable FAv2.1+ for causal mask in cross attention (#522)
* disable FAv2.1 if causal+cross attn

Signed-off-by: Charlene Yang <[email protected]>

* remove comment and add warning

Signed-off-by: Charlene Yang <[email protected]>

* include both causal and padding+causal

Signed-off-by: Charlene Yang <[email protected]>

* add a space

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
cyanguwa authored Nov 17, 2023
1 parent 1508821 commit da55d24
Showing 1 changed file with 11 additions and 0 deletions.
transformer_engine/pytorch/attention.py: 11 additions & 0 deletions
@@ -56,6 +56,7 @@
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("1.0.6")
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
+_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")

 if _flash_attn_2_available:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
@@ -2134,6 +2135,16 @@ def forward(
         if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
             use_flash_attention = False

+        if (_flash_attn_2_1_plus
+            and causal_mask
+            and max_seqlen_q != max_seqlen_kv):
+            warnings.warn(
+                "Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
+                "for causal mask in cross attention. See "
+                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
+            )
+            use_flash_attention = False
+
         if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
             use_flash_attention = False

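For context, and not part of this commit: flash-attn 2.1 changed causal=True to align the causal mask to the bottom-right corner of the attention matrix rather than the top-left, so when seqlen_q != seqlen_kv (cross attention) a query row attends to a different set of key positions than under 2.0. A minimal sketch of the two alignments; the helper build_causal_mask is hypothetical and only illustrates the documented behavior change:

```python
# Hypothetical helper, for illustration only: reproduces the mask alignment
# described in the flash-attention README; not a TE or flash-attn API.
import torch

def build_causal_mask(seqlen_q: int, seqlen_kv: int, bottom_right: bool) -> torch.Tensor:
    """True where query position i may attend to key position j."""
    q = torch.arange(seqlen_q).unsqueeze(1)   # (seqlen_q, 1)
    k = torch.arange(seqlen_kv).unsqueeze(0)  # (1, seqlen_kv)
    offset = seqlen_kv - seqlen_q if bottom_right else 0
    return k <= q + offset

# seqlen_q=2, seqlen_kv=5, as in the flash-attention README example:
print(build_causal_mask(2, 5, bottom_right=False).int())  # flash-attn 2.0: top-left aligned
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0]])
print(build_causal_mask(2, 5, bottom_right=True).int())   # flash-attn 2.1+: bottom-right aligned
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```

Because the two alignments disagree exactly when max_seqlen_q != max_seqlen_kv, the diff above falls back from FlashAttention for causal cross attention rather than silently changing results.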

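Separately, a standalone sketch of the same version gate for code outside TE, assuming flash-attn is installed; the helper flash_ok_for_causal_cross_attn and its arguments are illustrative, but the 2.1+ check mirrors the module-level _flash_attn_2_1_plus flag added above:

```python
# Sketch only: derive a flash-attn >= 2.1 flag and skip the fused path for
# causal cross attention, mirroring the guard added in this commit.
import warnings
from importlib.metadata import version

from packaging.version import Version

_flash_attn_2_1_plus = Version(version("flash-attn")) >= Version("2.1")

def flash_ok_for_causal_cross_attn(max_seqlen_q: int, max_seqlen_kv: int) -> bool:
    """Return False when flash-attn 2.1+ would reinterpret causal masking for
    mismatched query/key sequence lengths (i.e. cross attention)."""
    if _flash_attn_2_1_plus and max_seqlen_q != max_seqlen_kv:
        warnings.warn(
            "flash-attn 2.1+ changed the behavior of the causal flag when "
            "seqlen_q != seqlen_kv; see "
            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
        )
        return False
    return True
```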