Skip to content

Commit da55d24

Browse files
authored
Disable FAv2.1+ for causal mask in cross attention (#522)
* disable FAv2.1 if causal+cross attn
* remove comment and add warning
* include both causal and padding+causal
* add a space

Signed-off-by: Charlene Yang <[email protected]>
1 parent 1508821 commit da55d24

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

transformer_engine/pytorch/attention.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
_flash_attn_version = packaging.version.Version(version("flash-attn"))
5757
_flash_attn_version_required = packaging.version.Version("1.0.6")
5858
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
59+
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
5960

6061
if _flash_attn_2_available:
6162
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
@@ -2134,6 +2135,16 @@ def forward(
21342135
if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
21352136
use_flash_attention = False
21362137

2138+
if (_flash_attn_2_1_plus
2139+
and causal_mask
2140+
and max_seqlen_q != max_seqlen_kv):
2141+
warnings.warn(
2142+
"Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
2143+
"for causal mask in cross attention. See "
2144+
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
2145+
)
2146+
use_flash_attention = False
2147+
21372148
if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
21382149
use_flash_attention = False
21392150

0 commit comments

Comments (0)