Commit 833c4bf

Improvements in attention_forward functions
1 parent d1b9236 commit 833c4bf

66 files changed: +2618 additions, -987 deletions


src/transformers/configuration_utils.py (5 additions, 0 deletions)

@@ -64,6 +64,10 @@ class PretrainedConfig(PushToHubMixin):
     - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
       config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
       [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
+    - **all_sub_configs_have_defaults** (`bool`) -- In general, if `is_composition == True`, the config object
+      must be initialized passing at least one of the sub-configs. But if
+      `all_sub_configs_have_defaults == True`, all sub-configs have defaults, so the config can be created
+      without arguments.
     - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
       outputs of the model during inference.
     - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
@@ -194,6 +198,7 @@ class PretrainedConfig(PushToHubMixin):
     base_config_key: str = ""
     sub_configs: dict[str, "PretrainedConfig"] = {}
     is_composition: bool = False
+    all_sub_configs_have_defaults: bool = False
     attribute_map: dict[str, str] = {}
     base_model_tp_plan: Optional[dict[str, Any]] = None
     base_model_pp_plan: Optional[dict[str, tuple[list[str]]]] = None

src/transformers/integrations/flash_attention.py (11 additions, 3 deletions)

@@ -19,9 +19,16 @@ def flash_attention_forward(
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, None]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
     # This is before the transpose
-    seq_len = query.shape[2]
+    q_len = query.shape[2]

     # FA2 uses non-transposed inputs
     query = query.transpose(1, 2)
@@ -51,7 +58,7 @@ def flash_attention_forward(
         key,
         value,
         attention_mask,
-        query_length=seq_len,
+        query_length=q_len,
         is_causal=module.is_causal,
         dropout=dropout,
         softmax_scale=scaling,
@@ -61,5 +68,6 @@ def flash_attention_forward(
         target_dtype=target_dtype,
         **kwargs,
     )
+    # attn_output: (batch, q_len, num_heads, head_dim)

     return attn_output, None
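For concreteness, the shape contract spelled out in the new docstring, with arbitrary example sizes (a sketch, not code from this commit):

```python
import torch

batch, num_heads, num_key_value_groups = 2, 8, 2
q_len, kv_len, head_dim = 5, 7, 64

# Grouped-query attention (GQA) layout: fewer key/value heads than query heads.
query = torch.randn(batch, num_heads, q_len, head_dim)
key = torch.randn(batch, num_key_value_groups, kv_len, head_dim)
value = torch.randn(batch, num_key_value_groups, kv_len, head_dim)

assert num_heads % num_key_value_groups == 0
q_per_kv = num_heads // num_key_value_groups  # here: 4 query heads share each KV head

# The FlashAttention kernels handle MQA/GQA natively, so `key` and `value` can keep
# `num_key_value_groups` heads; no repeat_kv expansion to `num_heads` is required
# before calling `flash_attention_forward`.
```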

src/transformers/integrations/flex_attention.py (7 additions, 0 deletions)

@@ -188,6 +188,13 @@ def flex_attention_forward(
     head_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
     block_mask = None
     causal_mask = None
     if isinstance(attention_mask, BlockMask):
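For reference, a standalone sketch of how PyTorch's FlexAttention consumes the same GQA-shaped tensors. This is not code from this commit; it assumes PyTorch 2.5+ (where `torch.nn.attention.flex_attention` ships), and depending on the version you may need a GPU and/or `torch.compile` for the call to run efficiently.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

batch, num_heads, num_key_value_groups = 2, 8, 2
q_len, kv_len, head_dim = 128, 128, 64

query = torch.randn(batch, num_heads, q_len, head_dim)
key = torch.randn(batch, num_key_value_groups, kv_len, head_dim)
value = torch.randn(batch, num_key_value_groups, kv_len, head_dim)


def causal(b, h, q_idx, kv_idx):
    # A query position may only attend to keys at the same or earlier positions.
    return q_idx >= kv_idx


block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device="cpu")
# `enable_gqa=True` lets key/value keep `num_key_value_groups` heads.
out = flex_attention(query, key, value, block_mask=block_mask, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```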

src/transformers/integrations/sdpa_attention.py (32 additions, 10 deletions)

@@ -25,14 +25,28 @@ def sdpa_attention_forward(
     scaling: Optional[float] = None,
     is_causal: Optional[bool] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, None]:
-    if hasattr(module, "num_key_value_groups"):
-        key = repeat_kv(key, module.num_key_value_groups)
-        value = repeat_kv(value, module.num_key_value_groups)
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
+    `scaled_dot_product_attention` is supposed to support
+    `num_key_value_groups < num_heads`, if `enable_gqa=True`.

+    """
+    assert query.ndim == key.ndim == value.ndim == 4
+    _, num_heads, q_len, _ = query.shape
+    _, num_key_value_groups, kv_len, _ = key.shape
+    assert query.shape[0] == key.shape[0] == value.shape[0]  # batch_size
+    assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
+    assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups
+    q_per_kv = num_heads // num_key_value_groups
     causal_mask = attention_mask
     if attention_mask is not None and causal_mask.ndim == 4:
-        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+        causal_mask = causal_mask[:, :, :, :kv_len]

     # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
     # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -44,22 +58,30 @@ def sdpa_attention_forward(
     # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
     # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
     if is_causal is None:
-        is_causal = query.shape[2] > 1 and causal_mask is None
+        is_causal = causal_mask is None and q_len > 1

     # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
     # We convert it to a bool for the SDPA kernel that only accepts bools.
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()

+    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    enable_gqa = q_per_kv > 1
+    query = query.contiguous()
+    key = key.contiguous()
+    value = value.contiguous()
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query,
-        key,
-        value,
+        query=query,
+        key=key,
+        value=value,
         attn_mask=causal_mask,
         dropout_p=dropout,
         scale=scaling,
         is_causal=is_causal,
+        enable_gqa=enable_gqa,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
-
+    # attn_output: (batch, q_len, num_heads, head_dim)
     return attn_output, None
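A standalone sanity check (arbitrary sizes, assuming a PyTorch release whose SDPA accepts `enable_gqa`) that the new `enable_gqa=True` path matches the old `repeat_kv` expansion it replaces:

```python
import torch
import torch.nn.functional as F

batch, num_heads, num_key_value_groups = 2, 8, 2
q_len, kv_len, head_dim = 5, 7, 64
q_per_kv = num_heads // num_key_value_groups

query = torch.randn(batch, num_heads, q_len, head_dim)
key = torch.randn(batch, num_key_value_groups, kv_len, head_dim)
value = torch.randn(batch, num_key_value_groups, kv_len, head_dim)

# New path: key/value keep num_key_value_groups heads; SDPA broadcasts them internally.
out_gqa = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

# Old path: expand key/value to num_heads first, which is what repeat_kv did
# (repeat_interleave along the head dimension).
key_rep = key.repeat_interleave(q_per_kv, dim=1)
value_rep = value.repeat_interleave(q_per_kv, dim=1)
out_ref = F.scaled_dot_product_attention(query, key_rep, value_rep)

print(torch.allclose(out_gqa, out_ref, atol=1e-5))  # expected: True
```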

src/transformers/models/aria/configuration_aria.py (5 additions, 0 deletions)

@@ -22,9 +22,14 @@

 from ...configuration_utils import PretrainedConfig
 from ...modeling_rope_utils import rope_config_validation
+from ...utils.import_utils import is_torch_available
 from ..auto import CONFIG_MAPPING, AutoConfig


+if is_torch_available():
+    pass
+
+
 class AriaTextConfig(PretrainedConfig):
     r"""
     This class handles the configuration for the text component of the Aria model.

src/transformers/models/aria/image_processing_aria.py (5 additions, 0 deletions)

@@ -37,6 +37,11 @@
     validate_preprocess_arguments,
 )
 from ...utils import TensorType
+from ...utils.import_utils import is_torch_available
+
+
+if is_torch_available():
+    pass


 def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:

src/transformers/models/aria/modeling_aria.py (101 additions, 43 deletions)

@@ -48,6 +48,7 @@

 if is_torch_available():
     import torch
+    import torch.nn.functional as F
     from torch import nn


@@ -355,7 +356,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
         """
         fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
         projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
+        fc1_output = F.silu(projection) * gate
         fc2_output = self.fc2(fc1_output, tokens_per_expert)
         return fc2_output

@@ -404,7 +405,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # Top K Routing
         logits = self.router(hidden_states)
         top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = nn.functional.softmax(top_logits, dim=-1)
+        scores = F.softmax(top_logits, dim=-1)

         original_dtype = top_indices.dtype

@@ -440,23 +441,117 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return output + shared_expert_output


+def _attention_compute_scores(
+    query: torch.Tensor,
+    key: torch.Tensor,
+) -> torch.Tensor:
+    nh_q = query.shape[1]
+    nh_k = key.shape[1]
+    # - query: (bs, nh_q, T_q, hs)
+    # - key: (bs, nh_k, T_k, hs)
+    q_per_kv = nh_q // nh_k
+    key_transposed = key.mT  # (bs, nh_k, hs, T_k)
+    if q_per_kv == 1:
+        return query @ key_transposed
+    else:
+        assert q_per_kv > 1
+        if nh_k > 1:
+            q_shape = query.shape[:1] + (nh_k, q_per_kv) + query.shape[2:]
+            _query = query.view(*q_shape)
+            key_transposed = key_transposed.unsqueeze(2)
+        else:
+            _query = query
+        # At this point:
+        # - _query: (bs, nh_k, q_per_kv, T_q, hs)
+        # - key_transposed: (bs, nh_k, 1, hs, T_k)
+        # - scores: (bs, nh_k, q_per_kv, T_q, T_k) -> (bs, nh_q, T_q, T_k)
+        scores = torch.matmul(_query, key_transposed)
+        s_shape = query.shape[:-1] + (key.shape[2],)
+        return scores.view(*s_shape)
+
+
+def _attention_compute_weighted_values(
+    scores: torch.Tensor,
+    value: torch.Tensor,
+) -> torch.Tensor:
+    nh_q = scores.shape[1]
+    nh_k = value.shape[1]
+    # - scores: (bs, nh_q, T_q, T_k)
+    # - value: (bs, nh_k, T_k, hs)
+    q_per_kv = nh_q // nh_k
+    if q_per_kv == 1:
+        return scores @ value
+    else:
+        if nh_k > 1:
+            s_shape = scores.shape[:1] + (nh_k, q_per_kv) + scores.shape[2:]
+            _scores = scores.view(*s_shape)
+            _value = value.unsqueeze(2)
+        else:
+            _scores = scores
+            _value = value
+        # At this point:
+        # - _scores: (bs, nh_k, q_per_kv, T_q, T_k)
+        # - _value: (bs, nh_k, 1, T_k, hs)
+        # - result: (bs, nh_k, q_per_kv, T_q, hs) -> (bs, nh_q, T_q, hs)
+        result = torch.matmul(_scores, _value)
+        r_shape = scores.shape[:-1] + (value.shape[-1],)
+        return result.view(*r_shape)
+
+
+def eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
+    assert query.ndim == key.ndim == value.ndim == 4
+    _, num_heads, q_len, _ = query.shape
+    _, num_key_value_groups, kv_len, _ = key.shape
+    assert query.shape[0] == key.shape[0] == value.shape[0]  # batch_size
+    assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
+    assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups
+
+    attn_weights = _attention_compute_scores(query, key) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, :kv_len]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = _attention_compute_weighted_values(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    # attn_output: (batch, q_len, num_heads, head_dim)
+    # attn_weights: (batch, num_heads, q_len, kv_len)
+
+    return attn_output, attn_weights
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -474,44 +569,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
 class AriaTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -946,6 +1003,7 @@ def forward(
                 use_cache,
                 cache_position,
                 position_embeddings,
+                **flash_attn_kwargs,
             )
         else:
             layer_outputs = decoder_layer(
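The new `_attention_compute_scores` avoids materializing a repeated key tensor: it reshapes `query` to `(bs, nh_k, q_per_kv, T_q, hs)` and lets matmul broadcast against the transposed key of shape `(bs, nh_k, 1, hs, T_k)`. A quick standalone check (arbitrary sizes, not code from this commit) that this matches the old `repeat_kv`-then-matmul path:

```python
import torch

bs, nh_q, nh_k = 2, 8, 2
T_q, T_k, hs = 5, 7, 16
q_per_kv = nh_q // nh_k

query = torch.randn(bs, nh_q, T_q, hs)
key = torch.randn(bs, nh_k, T_k, hs)

# Broadcasted path: no repeated key tensor is materialized.
_query = query.view(bs, nh_k, q_per_kv, T_q, hs)
key_t = key.mT.unsqueeze(2)                       # (bs, nh_k, 1, hs, T_k)
scores = torch.matmul(_query, key_t).view(bs, nh_q, T_q, T_k)

# Reference path: repeat_kv (repeat_interleave on the head dim), then plain matmul.
key_rep = key.repeat_interleave(q_per_kv, dim=1)  # (bs, nh_q, T_k, hs)
scores_ref = torch.matmul(query, key_rep.transpose(2, 3))

print(torch.allclose(scores, scores_ref, atol=1e-5))  # expected: True
```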

src/transformers/models/aria/modular_aria.py (3 additions, 2 deletions)

@@ -68,6 +68,7 @@

 if is_torch_available():
     import torch
+    import torch.nn.functional as F
     from torch import nn


@@ -1096,7 +1097,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
         """
         fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
         projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
+        fc1_output = F.silu(projection) * gate
         fc2_output = self.fc2(fc1_output, tokens_per_expert)
         return fc2_output

@@ -1145,7 +1146,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # Top K Routing
         logits = self.router(hidden_states)
         top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = nn.functional.softmax(top_logits, dim=-1)
+        scores = F.softmax(top_logits, dim=-1)

         original_dtype = top_indices.dtype