
Commit 7f1ff55

Improvements in attention_forward functions
1 parent dd16acb · commit 7f1ff55

35 files changed: +1810 −400 lines changed

src/transformers/integrations/flash_attention.py

Lines changed: 11 additions & 3 deletions
@@ -20,9 +20,16 @@ def flash_attention_forward(
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, None]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
     # This is before the transpose
-    seq_len = query.shape[2]
+    q_len = query.shape[2]

     # FA2 uses non-transposed inputs
     query = query.transpose(1, 2)
@@ -52,7 +59,7 @@ def flash_attention_forward(
         key,
         value,
         attention_mask,
-        query_length=seq_len,
+        query_length=q_len,
         is_causal=module.is_causal,
         dropout=dropout,
         softmax_scale=scaling,
@@ -62,5 +69,6 @@ def flash_attention_forward(
         target_dtype=target_dtype,
         **kwargs,
     )
+    # attn_output: (batch, q_len, num_heads, head_dim)

     return attn_output, None
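A minimal sketch (not part of the commit, sizes made up) of the shape contract the new docstring describes, and of the `(1, 2)` transpose into the layout FA2 expects; it does not call the actual `flash_attention_forward`:

import torch

# Illustrative, hypothetical sizes
batch, num_heads, num_key_value_groups, head_dim = 2, 8, 2, 16
q_len, kv_len = 5, 12

query = torch.randn(batch, num_heads, q_len, head_dim)
key = torch.randn(batch, num_key_value_groups, kv_len, head_dim)
value = torch.randn(batch, num_key_value_groups, kv_len, head_dim)

# Contract documented above: fewer KV groups than heads, evenly divisible
assert num_key_value_groups <= num_heads
assert num_heads % num_key_value_groups == 0

# "FA2 uses non-transposed inputs": (batch, seq_len, heads, head_dim)
q_fa2 = query.transpose(1, 2)   # (2, 5, 8, 16)
k_fa2 = key.transpose(1, 2)     # (2, 12, 2, 16)
v_fa2 = value.transpose(1, 2)   # (2, 12, 2, 16)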

src/transformers/integrations/flex_attention.py

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,13 @@ def flex_attention_forward(
     head_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
     causal_mask = attention_mask
     if causal_mask is not None:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]
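A tiny sketch (not part of the commit, sizes made up) of what the unchanged mask slicing in the context above does: the 4-D attention mask, allocated out to a longer cache length, is trimmed to the actual key length before use.

import torch

batch, q_len, kv_len, max_cache_len = 2, 5, 12, 32
attention_mask = torch.zeros(batch, 1, q_len, max_cache_len)
key = torch.randn(batch, 2, kv_len, 16)

causal_mask = attention_mask
if causal_mask is not None:
    causal_mask = causal_mask[:, :, :, : key.shape[-2]]  # trim to kv_len
assert causal_mask.shape == (batch, 1, q_len, kv_len)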

src/transformers/integrations/sdpa_attention.py

Lines changed: 37 additions & 17 deletions
@@ -25,40 +25,60 @@ def sdpa_attention_forward(
     scaling: Optional[float] = None,
     is_causal: Optional[bool] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, None]:
-    if hasattr(module, "num_key_value_groups"):
-        key = repeat_kv(key, module.num_key_value_groups)
-        value = repeat_kv(value, module.num_key_value_groups)
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
+    `scaled_dot_product_attention` is supposed to support
+    `num_key_value_groups < num_heads`. But at least up to PyTorch 2.5.1, this
+    does not seem to be supported.

+    """
+    assert query.ndim == key.ndim == value.ndim == 4
+    _, num_heads, q_len, _ = query.shape
+    _, num_key_value_groups, kv_len, _ = key.shape
+    assert query.shape[0] == key.shape[0] == value.shape[0]  # batch_size
+    assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
+    assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups
+    q_per_kv = num_heads // num_key_value_groups
     causal_mask = attention_mask
     if attention_mask is not None:
-        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
-
-    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
-    # Reference: https://github.com/pytorch/pytorch/issues/112577.
-    query = query.contiguous()
-    key = key.contiguous()
-    value = value.contiguous()
-
+        causal_mask = causal_mask[:, :, :, :kv_len]
     # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
     # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
     if is_causal is None:
-        is_causal = causal_mask is None and query.shape[2] > 1
+        is_causal = causal_mask is None and q_len > 1

     # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
     # We convert it to a bool for the SDPA kernel that only accepts bools.
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()

+    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    query = query.contiguous()
+    if q_per_kv == 1:
+        key = key.contiguous()
+        value = value.contiguous()
+    else:
+        # TODO: Once PyTorch SDPA supports q_per_kv > 1, this should be
+        # removed!
+        key = repeat_kv(key, n_rep=q_per_kv).contiguous()
+        value = repeat_kv(value, n_rep=q_per_kv).contiguous()
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query,
-        key,
-        value,
+        query=query,
+        key=key,
+        value=value,
         attn_mask=causal_mask,
         dropout_p=dropout,
         scale=scaling,
         is_causal=is_causal,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
-
+    # attn_output: (batch, q_len, num_heads, head_dim)
     return attn_output, None
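A self-contained sketch (not part of the commit, sizes made up) of the fallback branch above: `key`/`value` are expanded to the full head count before calling SDPA. Per the docstring of the removed `repeat_kv` elsewhere in this commit, the expansion is equivalent to `torch.repeat_interleave` along the head dimension, which is used here as a stand-in.

import torch
import torch.nn.functional as F

batch, num_heads, num_key_value_groups, head_dim = 2, 8, 2, 16
q_len, kv_len = 5, 12
q_per_kv = num_heads // num_key_value_groups  # 4

query = torch.randn(batch, num_heads, q_len, head_dim)
key = torch.randn(batch, num_key_value_groups, kv_len, head_dim)
value = torch.randn(batch, num_key_value_groups, kv_len, head_dim)

# Stand-in for `repeat_kv(x, n_rep)`: duplicate each KV head q_per_kv times
key_rep = torch.repeat_interleave(key, repeats=q_per_kv, dim=1)
value_rep = torch.repeat_interleave(value, repeats=q_per_kv, dim=1)

attn_output = F.scaled_dot_product_attention(
    query=query,
    key=key_rep,
    value=value_rep,
    is_causal=False,
)
assert attn_output.shape == (batch, num_heads, q_len, head_dim)

The expansion costs memory proportional to `q_per_kv`; the TODO in the diff indicates this branch is meant to go away once SDPA handles grouped KV heads directly.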

src/transformers/models/aria/configuration_aria.py

Lines changed: 5 additions & 0 deletions
@@ -22,9 +22,14 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...modeling_rope_utils import rope_config_validation
+from ...utils.import_utils import is_torch_available
 from ..auto import CONFIG_MAPPING, AutoConfig
 
 
+if is_torch_available():
+    pass
+
+
 class AriaTextConfig(PretrainedConfig):
     r"""
     This class handles the configuration for the text component of the Aria model.

src/transformers/models/aria/image_processing_aria.py

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,11 @@
     validate_preprocess_arguments,
 )
 from ...utils import TensorType
+from ...utils.import_utils import is_torch_available
+
+
+if is_torch_available():
+    pass
 
 
 def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:

src/transformers/models/aria/modeling_aria.py

Lines changed: 80 additions & 20 deletions
@@ -45,6 +45,7 @@
 
 if is_torch_available():
     import torch
+    import torch.nn.functional as F
     from torch import nn
 
 
@@ -346,7 +347,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
         """
         fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
         projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
+        fc1_output = F.silu(projection) * gate
         fc2_output = self.fc2(fc1_output, tokens_per_expert)
         return fc2_output
 
@@ -395,7 +396,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # Top K Routing
         logits = self.router(hidden_states)
         top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = nn.functional.softmax(top_logits, dim=-1)
+        scores = F.softmax(top_logits, dim=-1)
 
         original_dtype = top_indices.dtype
 
@@ -465,16 +466,61 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+def _attention_compute_scores(
+    query: torch.Tensor,
+    key: torch.Tensor,
+) -> torch.Tensor:
+    nh_q = query.shape[1]
+    nh_k = key.shape[1]
+    # - query: (bs, nh_q, T_q, hs)
+    # - key: (bs, nh_k, T_k, hs)
+    q_per_kv = nh_q // nh_k
+    key_transposed = key.mT  # (bs, nh_k, hs, T_k)
+    if q_per_kv == 1:
+        return query @ key_transposed
+    else:
+        assert q_per_kv > 1
+        if nh_k > 1:
+            q_shape = query.shape[:1] + (nh_k, q_per_kv) + query.shape[2:]
+            _query = query.view(*q_shape)
+            key_transposed = key_transposed.unsqueeze(2)
+        else:
+            _query = query
+        # At this point:
+        # - _query: (bs, nh_k, q_per_kv, T_q, hs)
+        # - key_transposed: (bs, nh_k, 1, hs, T_k)
+        # - scores: (bs, nh_k, q_per_kv, T_q, T_k) -> (bs, nh_q, T_q, T_k)
+        scores = torch.matmul(_query, key_transposed)
+        s_shape = query.shape[:-1] + (key.shape[2],)
+        return scores.view(*s_shape)
+
+
+def _attention_compute_weighted_values(
+    scores: torch.Tensor,
+    value: torch.Tensor,
+) -> torch.Tensor:
+    nh_q = scores.shape[1]
+    nh_k = value.shape[1]
+    # - scores: (bs, nh_q, T_q, T_k)
+    # - value: (bs, nh_k, T_k, hs)
+    q_per_kv = nh_q // nh_k
+    if q_per_kv == 1:
+        return scores @ value
+    else:
+        if nh_k > 1:
+            s_shape = scores.shape[:1] + (nh_k, q_per_kv) + scores.shape[2:]
+            _scores = scores.view(*s_shape)
+            _value = value.unsqueeze(2)
+        else:
+            _scores = scores
+            _value = value
+        # At this point:
+        # - _scores: (bs, nh_k, q_per_kv, T_q, T_k)
+        # - _value: (bs, nh_k, 1, T_k, hs)
+        # - result: (bs, nh_k, q_per_kv, T_q, hs) -> (bs, nh_q, T_q, hs)
+        result = torch.matmul(_scores, _value)
+        r_shape = scores.shape[:-1] + (value.shape[-1],)
+        return result.view(*r_shape)
 
 
 def eager_attention_forward(
@@ -486,19 +532,32 @@ def eager_attention_forward(
     scaling: float,
     dropout: float = 0.0,
     **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.

-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    """
+    assert query.ndim == key.ndim == value.ndim == 4
+    _, num_heads, q_len, _ = query.shape
+    _, num_key_value_groups, kv_len, _ = key.shape
+    assert query.shape[0] == key.shape[0] == value.shape[0]  # batch_size
+    assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
+    assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups
+
+    attn_weights = _attention_compute_scores(query, key) * scaling
     if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        causal_mask = attention_mask[:, :, :, :kv_len]
         attn_weights = attn_weights + causal_mask
 
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
+    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = _attention_compute_weighted_values(attn_weights, value)
     attn_output = attn_output.transpose(1, 2).contiguous()
+    # attn_output: (batch, q_len, num_heads, head_dim)
+    # attn_weights: (batch, num_heads, q_len, kv_len)
 
     return attn_output, attn_weights
 
@@ -969,6 +1028,7 @@ def forward(
                     use_cache,
                     cache_position,
                     position_embeddings,
+                    **flash_attn_kwargs,
                 )
             else:
                 layer_outputs = decoder_layer(
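A quick sanity sketch (not part of the commit, sizes made up) of why the grouped reshaping in `_attention_compute_scores` / `_attention_compute_weighted_values` reproduces what the removed `repeat_kv` path computed. `torch.repeat_interleave` stands in for `repeat_kv`, which the old docstring described as equivalent.

import torch

bs, nh_q, nh_k, T_q, T_k, hs = 2, 8, 2, 5, 12, 16
q_per_kv = nh_q // nh_k

query = torch.randn(bs, nh_q, T_q, hs)
key = torch.randn(bs, nh_k, T_k, hs)
value = torch.randn(bs, nh_k, T_k, hs)

# Old path: materialize repeated K/V, then plain batched matmuls
key_rep = torch.repeat_interleave(key, repeats=q_per_kv, dim=1)
value_rep = torch.repeat_interleave(value, repeats=q_per_kv, dim=1)
scores_old = query @ key_rep.mT                    # (bs, nh_q, T_q, T_k)
out_old = scores_old.softmax(dim=-1) @ value_rep   # (bs, nh_q, T_q, hs)

# New path: fold heads into (nh_k, q_per_kv) groups and broadcast K/V
_query = query.view(bs, nh_k, q_per_kv, T_q, hs)
scores_new = torch.matmul(_query, key.mT.unsqueeze(2)).view(bs, nh_q, T_q, T_k)
_weights = scores_new.softmax(dim=-1).view(bs, nh_k, q_per_kv, T_q, T_k)
out_new = torch.matmul(_weights, value.unsqueeze(2)).view(bs, nh_q, T_q, hs)

torch.testing.assert_close(scores_old, scores_new)
torch.testing.assert_close(out_old, out_new)

The grouped form avoids materializing key/value copies that are `q_per_kv` times larger, which appears to be the motivation for replacing `repeat_kv` in the eager path.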

src/transformers/models/aria/modular_aria.py

Lines changed: 3 additions & 2 deletions
@@ -66,6 +66,7 @@
 
 if is_torch_available():
     import torch
+    import torch.nn.functional as F
     from torch import nn
 
 
@@ -1094,7 +1095,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
         """
         fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
         projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
+        fc1_output = F.silu(projection) * gate
         fc2_output = self.fc2(fc1_output, tokens_per_expert)
         return fc2_output
 
@@ -1143,7 +1144,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # Top K Routing
         logits = self.router(hidden_states)
         top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = nn.functional.softmax(top_logits, dim=-1)
+        scores = F.softmax(top_logits, dim=-1)
 
         original_dtype = top_indices.dtype
 
src/transformers/models/aria/processing_aria.py

Lines changed: 5 additions & 0 deletions
@@ -25,9 +25,14 @@
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils import PreTokenizedInput, TextInput
 from ...utils import TensorType
+from ...utils.import_utils import is_torch_available
 from ..auto import AutoTokenizer
 
 
+if is_torch_available():
+    pass
+
+
 class AriaProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {
         "text_kwargs": {
