Commit 509a3f4

Improvements in attention_forward functions
1 parent b7fc2da commit 509a3f4

40 files changed: +429 -915 lines changed

src/transformers/configuration_utils.py

Lines changed: 5 additions & 0 deletions
@@ -67,6 +67,10 @@ class PretrainedConfig(PushToHubMixin):
     - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
       config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
       [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
+    - **all_sub_configs_have_defaults** (`bool`) -- In general, if `is_composition == True`, the config object
+      must be initialized passing at least one of the sub-configs. But if
+      `all_sub_configs_have_defaults == True`, all sub-configs have defaults, so the config can be created
+      without arguments.
     - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
       outputs of the model during inference.
     - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
@@ -197,6 +201,7 @@ class PretrainedConfig(PushToHubMixin):
     base_config_key: str = ""
     sub_configs: dict[str, "PretrainedConfig"] = {}
     is_composition: bool = False
+    all_sub_configs_have_defaults: bool = False
     attribute_map: dict[str, str] = {}
     base_model_tp_plan: Optional[dict[str, Any]] = None
     base_model_pp_plan: Optional[dict[str, tuple[list[str]]]] = None
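
To illustrate what the new flag enables, here is a minimal sketch with invented classes (`MyTextConfig` and `MyComposedConfig` are not part of this commit): a composed config whose sub-configs all have defaults can be instantiated with no arguments.

# Hypothetical example of `all_sub_configs_have_defaults` (class names are invented for
# illustration; only the flag itself is introduced by this commit).
from transformers import PretrainedConfig


class MyTextConfig(PretrainedConfig):
    model_type = "my_text"

    def __init__(self, hidden_size=64, **kwargs):
        self.hidden_size = hidden_size
        super().__init__(**kwargs)


class MyComposedConfig(PretrainedConfig):
    model_type = "my_composed"
    sub_configs = {"text_config": MyTextConfig}
    is_composition = True
    all_sub_configs_have_defaults = True  # every sub-config can be constructed without arguments

    def __init__(self, text_config=None, **kwargs):
        # Because the sub-config has defaults, `MyComposedConfig()` works with no arguments.
        self.text_config = text_config if text_config is not None else MyTextConfig()
        super().__init__(**kwargs)


config = MyComposedConfig()  # no sub-config passed; defaults are used
print(config.text_config.hidden_size)  # 64
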
src/transformers/integrations/eager_attention.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def _attention_compute_scores(
+    query: torch.Tensor,
+    key: torch.Tensor,
+) -> torch.Tensor:
+    nh_q = query.shape[1]
+    nh_k = key.shape[1]
+    # - query: (bs, nh_q, T_q, hs)
+    # - key: (bs, nh_k, T_k, hs)
+    q_per_kv = nh_q // nh_k
+    key_transposed = key.mT  # (bs, nh_k, hs, T_k)
+    if q_per_kv == 1:
+        return query @ key_transposed
+    else:
+        assert q_per_kv > 1
+        if nh_k > 1:
+            q_shape = query.shape[:1] + (nh_k, q_per_kv) + query.shape[2:]
+            _query = query.view(*q_shape)
+            key_transposed = key_transposed.unsqueeze(2)
+        else:
+            _query = query
+        # At this point:
+        # - _query: (bs, nh_k, q_per_kv, T_q, hs)
+        # - key_transposed: (bs, nh_k, 1, hs, T_k)
+        # - scores: (bs, nh_k, q_per_kv, T_q, T_k) -> (bs, nh_q, T_q, T_k)
+        scores = torch.matmul(_query, key_transposed)
+        s_shape = query.shape[:-1] + (key.shape[2],)
+        return scores.view(*s_shape)
+
+
+def _attention_compute_weighted_values(
+    scores: torch.Tensor,
+    value: torch.Tensor,
+) -> torch.Tensor:
+    nh_q = scores.shape[1]
+    nh_k = value.shape[1]
+    # - scores: (bs, nh_q, T_q, T_k)
+    # - value: (bs, nh_k, T_k, hs)
+    q_per_kv = nh_q // nh_k
+    if q_per_kv == 1:
+        return scores @ value
+    else:
+        if nh_k > 1:
+            s_shape = scores.shape[:1] + (nh_k, q_per_kv) + scores.shape[2:]
+            _scores = scores.view(*s_shape)
+            _value = value.unsqueeze(2)
+        else:
+            _scores = scores
+            _value = value
+        # At this point:
+        # - _scores: (bs, nh_k, q_per_kv, T_q, T_k)
+        # - _value: (bs, nh_k, 1, T_k, hs)
+        # - result: (bs, nh_k, q_per_kv, T_q, hs) -> (bs, nh_q, T_q, hs)
+        result = torch.matmul(_scores, _value)
+        r_shape = scores.shape[:-1] + (value.shape[-1],)
+        return result.view(*r_shape)
+
+
+def eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
+    assert query.ndim == key.ndim == value.ndim == 4
+    _, num_heads, q_len, _ = query.shape
+    _, num_key_value_groups, kv_len, _ = key.shape
+    assert query.shape[0] == key.shape[0] == value.shape[0]  # batch_size
+    assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
+    assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups
+
+    attn_weights = _attention_compute_scores(query, key) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, :kv_len]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = _attention_compute_weighted_values(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    # attn_output: (batch, q_len, num_heads, head_dim)
+    # attn_weights: (batch, num_heads, q_len, kv_len)
+
+    return attn_output, attn_weights
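
A quick shape sketch of how the new helper behaves under grouped-query attention. The dummy module, tensor sizes, and the `transformers.integrations.eager_attention` import path are assumptions (the path is inferred from the import added in modeling_aria.py below).

# Illustrative shape check for the shared grouped-query eager attention
# (dummy module and sizes are made up; import path assumed from the new import in modeling_aria.py).
import torch

from transformers.integrations.eager_attention import eager_attention_forward

batch, num_heads, num_kv_groups, q_len, kv_len, head_dim = 2, 8, 2, 5, 7, 16
query = torch.randn(batch, num_heads, q_len, head_dim)
key = torch.randn(batch, num_kv_groups, kv_len, head_dim)
value = torch.randn(batch, num_kv_groups, kv_len, head_dim)

attn_output, attn_weights = eager_attention_forward(
    torch.nn.Module(), query, key, value, attention_mask=None, scaling=head_dim**-0.5
)
print(attn_output.shape)   # torch.Size([2, 5, 8, 16]) -> (batch, q_len, num_heads, head_dim)
print(attn_weights.shape)  # torch.Size([2, 8, 5, 7])  -> (batch, num_heads, q_len, kv_len)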

src/transformers/integrations/flash_attention.py

Lines changed: 11 additions & 3 deletions
@@ -20,9 +20,16 @@ def flash_attention_forward(
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, None]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
     # This is before the transpose
-    seq_len = query.shape[2]
+    q_len = query.shape[2]

     # FA2 uses non-transposed inputs
     query = query.transpose(1, 2)
@@ -52,7 +59,7 @@ def flash_attention_forward(
        key,
        value,
        attention_mask,
-        query_length=seq_len,
+        query_length=q_len,
        is_causal=module.is_causal,
        dropout=dropout,
        softmax_scale=scaling,
@@ -62,5 +69,6 @@ def flash_attention_forward(
        target_dtype=target_dtype,
        **kwargs,
    )
+    # attn_output: (batch, q_len, num_heads, head_dim)

    return attn_output, None
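
As a side note on the `seq_len` to `q_len` rename and the "before the transpose" comment, a tiny layout sketch with invented sizes: the attention interface hands tensors over as `(batch, num_heads, seq_len, head_dim)`, while the FA2 path consumes `(batch, seq_len, num_heads, head_dim)`, so the query length must be read from dim 2 before transposing.

# Illustrative only: why the query length is captured before the transpose.
import torch

batch, num_heads, q_len, head_dim = 2, 8, 5, 16
query = torch.randn(batch, num_heads, q_len, head_dim)  # interface layout

q_len_read = query.shape[2]    # read while still (batch, num_heads, q_len, head_dim)
query = query.transpose(1, 2)  # (batch, q_len, num_heads, head_dim), the layout FA2 expects
assert query.shape[1] == q_len_read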

src/transformers/integrations/flex_attention.py

Lines changed: 7 additions & 0 deletions
@@ -155,6 +155,13 @@ def flex_attention_forward(
     head_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
     block_mask = None
     causal_mask = None
     if isinstance(attention_mask, BlockMask):

src/transformers/integrations/sdpa_attention.py

Lines changed: 32 additions & 17 deletions
@@ -25,41 +25,56 @@ def sdpa_attention_forward(
     scaling: Optional[float] = None,
     is_causal: Optional[bool] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, None]:
-    if hasattr(module, "num_key_value_groups"):
-        key = repeat_kv(key, module.num_key_value_groups)
-        value = repeat_kv(value, module.num_key_value_groups)
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
+    `scaled_dot_product_attention` is supposed to support
+    `num_key_value_groups < num_heads`, if `enable_gqa=True`.

+    """
+    assert query.ndim == key.ndim == value.ndim == 4
+    _, num_heads, q_len, _ = query.shape
+    _, num_key_value_groups, kv_len, _ = key.shape
+    assert query.shape[0] == key.shape[0] == value.shape[0]  # batch_size
+    assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
+    assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups
+    q_per_kv = num_heads // num_key_value_groups
     causal_mask = attention_mask
     if attention_mask is not None:
-        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
-
-    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
-    # Reference: https://github.com/pytorch/pytorch/issues/112577.
-    query = query.contiguous()
-    key = key.contiguous()
-    value = value.contiguous()
-
+        causal_mask = causal_mask[:, :, :, :kv_len]
     # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
     # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
     # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
     if is_causal is None:
-        is_causal = query.shape[2] > 1 and causal_mask is None
+        is_causal = causal_mask is None and q_len > 1

     # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
     # We convert it to a bool for the SDPA kernel that only accepts bools.
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()

+    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    enable_gqa = q_per_kv > 1
+    query = query.contiguous()
+    key = key.contiguous()
+    value = value.contiguous()
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query,
-        key,
-        value,
+        query=query,
+        key=key,
+        value=value,
         attn_mask=causal_mask,
         dropout_p=dropout,
         scale=scaling,
         is_causal=is_causal,
+        enable_gqa=enable_gqa,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
-
+    # attn_output: (batch, q_len, num_heads, head_dim)
     return attn_output, None
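
To make the `enable_gqa` change concrete, a small equivalence sketch (sizes invented; requires a PyTorch whose SDPA exposes `enable_gqa`, i.e. 2.5 or newer): the grouped call should match the previous behavior of explicitly repeating the key/value heads.

# Illustrative check that `enable_gqa=True` matches explicitly repeating key/value heads.
# Sizes are made up; `enable_gqa` needs a recent PyTorch (>= 2.5).
import torch
import torch.nn.functional as F

batch, num_heads, num_kv_groups, q_len, kv_len, head_dim = 2, 8, 2, 5, 7, 16
q = torch.randn(batch, num_heads, q_len, head_dim)
k = torch.randn(batch, num_kv_groups, kv_len, head_dim)
v = torch.randn(batch, num_kv_groups, kv_len, head_dim)

# Grouped-query call: key/value keep their num_key_value_groups heads.
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Reference: expand key/value to num_heads (what repeat_kv used to do), then call SDPA normally.
repeat = num_heads // num_kv_groups
out_ref = F.scaled_dot_product_attention(
    q, k.repeat_interleave(repeat, dim=1), v.repeat_interleave(repeat, dim=1)
)
torch.testing.assert_close(out_gqa, out_ref)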

src/transformers/models/aria/configuration_aria.py

Lines changed: 5 additions & 0 deletions
@@ -22,9 +22,14 @@

 from ...configuration_utils import PretrainedConfig
 from ...modeling_rope_utils import rope_config_validation
+from ...utils.import_utils import is_torch_available
 from ..auto import CONFIG_MAPPING, AutoConfig


+if is_torch_available():
+    pass
+
+
 class AriaTextConfig(PretrainedConfig):
     r"""
     This class handles the configuration for the text component of the Aria model.

src/transformers/models/aria/image_processing_aria.py

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,11 @@
     validate_preprocess_arguments,
 )
 from ...utils import TensorType
+from ...utils.import_utils import is_torch_available
+
+
+if is_torch_available():
+    pass


 def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:

src/transformers/models/aria/modeling_aria.py

Lines changed: 5 additions & 40 deletions
@@ -24,6 +24,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
+from ...integrations.eager_attention import eager_attention_forward
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
@@ -46,6 +47,7 @@

 if is_torch_available():
     import torch
+    import torch.nn.functional as F
     from torch import nn


@@ -353,7 +355,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
         """
         fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
         projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
+        fc1_output = F.silu(projection) * gate
         fc2_output = self.fc2(fc1_output, tokens_per_expert)
         return fc2_output

@@ -402,7 +404,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # Top K Routing
         logits = self.router(hidden_states)
         top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = nn.functional.softmax(top_logits, dim=-1)
+        scores = F.softmax(top_logits, dim=-1)

         original_dtype = top_indices.dtype

@@ -472,44 +474,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
 class AriaTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -972,6 +936,7 @@ def forward(
                     use_cache,
                     cache_position,
                     position_embeddings,
+                    **flash_attn_kwargs,
                 )
             else:
                 layer_outputs = decoder_layer(
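
For reference, the repeat_kv-based eager path removed here computes the same attention scores as the grouped matmul in the shared `eager_attention_forward`; a small sketch with invented sizes:

# Sanity sketch (made-up sizes): expanding key heads via repeat_kv and doing one large matmul
# gives the same scores as the grouped matmul used by the shared eager attention helper.
import torch

query = torch.randn(2, 8, 5, 16)  # (batch, num_heads, q_len, head_dim)
key = torch.randn(2, 2, 7, 16)    # (batch, num_key_value_groups, kv_len, head_dim)
repeat = query.shape[1] // key.shape[1]  # 4 query heads per key/value head

# Old path: repeat key heads up to num_heads, then a single matmul.
key_rep = key[:, :, None, :, :].expand(2, 2, repeat, 7, 16).reshape(2, 8, 7, 16)
scores_old = torch.matmul(query, key_rep.transpose(2, 3))

# New path: keep key at num_key_value_groups heads and batch the matmul over groups.
q_grouped = query.view(2, 2, repeat, 5, 16)
scores_new = torch.matmul(q_grouped, key.mT.unsqueeze(2)).view(2, 8, 5, 7)

torch.testing.assert_close(scores_old, scores_new)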

src/transformers/models/aria/modular_aria.py

Lines changed: 3 additions & 2 deletions
@@ -66,6 +66,7 @@

 if is_torch_available():
     import torch
+    import torch.nn.functional as F
     from torch import nn


@@ -1094,7 +1095,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
         """
         fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
         projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
+        fc1_output = F.silu(projection) * gate
         fc2_output = self.fc2(fc1_output, tokens_per_expert)
         return fc2_output

@@ -1143,7 +1144,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # Top K Routing
         logits = self.router(hidden_states)
         top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = nn.functional.softmax(top_logits, dim=-1)
+        scores = F.softmax(top_logits, dim=-1)

         original_dtype = top_indices.dtype

src/transformers/models/aria/processing_aria.py

Lines changed: 5 additions & 0 deletions
@@ -25,9 +25,14 @@
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils import PreTokenizedInput, TextInput
 from ...utils import TensorType
+from ...utils.import_utils import is_torch_available
 from ..auto import AutoTokenizer


+if is_torch_available():
+    pass
+
+
 class AriaProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {
         "text_kwargs": {

0 commit comments
