Improvements in attention_forward functions #36218

Open · wants to merge 1 commit into base: main
14 changes: 11 additions & 3 deletions src/transformers/integrations/flash_attention.py
@@ -22,15 +22,22 @@ def flash_attention_forward(
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
`query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
`value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
`num_key_value_groups <= num_heads` and
`num_heads % num_key_value_groups == 0`.

"""
if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
logger.warning_once(
"`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
" Please set your attention to `eager` if you want any of these features."
)

# This is before the transpose
seq_len = query.shape[2]
q_len = query.shape[2]

# FA2 uses non-transposed inputs
query = query.transpose(1, 2)
@@ -60,7 +67,7 @@ def flash_attention_forward(
key,
value,
attention_mask,
query_length=seq_len,
query_length=q_len,
is_causal=module.is_causal,
dropout=dropout,
softmax_scale=scaling,
@@ -70,5 +77,6 @@
target_dtype=target_dtype,
**kwargs,
)
# attn_output: (batch, q_len, num_heads, head_dim)

return attn_output, None
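
Note on the shape contract the new docstring spells out: a minimal standalone sketch of those invariants is below. The `check_gqa_shapes` helper and the concrete sizes are made up for illustration and are not part of this PR.

```python
import torch

def check_gqa_shapes(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> None:
    # Hypothetical helper mirroring the documented contract:
    #   query:      (batch, num_heads, q_len, head_dim)
    #   key/value:  (batch, num_key_value_groups, kv_len, head_dim)
    assert query.ndim == key.ndim == value.ndim == 4
    batch, num_heads, _, head_dim = query.shape
    _, num_kv_groups, kv_len, _ = key.shape
    assert key.shape[0] == value.shape[0] == batch
    assert value.shape == (batch, num_kv_groups, kv_len, head_dim)
    assert num_kv_groups <= num_heads and num_heads % num_kv_groups == 0

# Example: 8 query heads sharing 2 key/value groups.
q = torch.randn(1, 8, 5, 64)
k = torch.randn(1, 2, 7, 64)
v = torch.randn(1, 2, 7, 64)
check_gqa_shapes(q, k, v)
```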
7 changes: 7 additions & 0 deletions src/transformers/integrations/flex_attention.py
@@ -235,6 +235,13 @@ def flex_attention_forward(
head_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
`query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
`value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
`num_key_value_groups <= num_heads` and
`num_heads % num_key_value_groups == 0`.

"""
if head_mask is not None:
logger.warning_once(
"`flex_attention` does not support `head_mask`. Please set your attention to `eager` if you want this feature."
40 changes: 33 additions & 7 deletions src/transformers/integrations/sdpa_attention.py
@@ -30,7 +30,18 @@ def sdpa_attention_forward(
scaling: Optional[float] = None,
is_causal: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
`query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
`value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
`num_key_value_groups <= num_heads` and
`num_heads % num_key_value_groups == 0`.

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
`scaled_dot_product_attention` is supposed to support
`num_key_value_groups < num_heads`, if `enable_gqa=True`.

"""
if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
logger.warning_once(
"`sdpa` attention does not support `output_attentions=True` or `head_mask`."
@@ -41,8 +52,15 @@
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)

assert query.ndim == key.ndim == value.ndim == 4
_, num_heads, q_len, _ = query.shape
_, num_key_value_groups, kv_len, _ = key.shape
assert query.shape[0] == key.shape[0] == value.shape[0] # batch_size
assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups
q_per_kv = num_heads // num_key_value_groups
if attention_mask is not None and attention_mask.ndim == 4:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attention_mask = attention_mask[:, :, :, :kv_len]

# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -56,22 +74,30 @@ def sdpa_attention_forward(
if is_causal is None:
# The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
# This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
is_causal = attention_mask is None and q_len > 1 and getattr(module, "is_causal", True)

# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
is_causal = is_causal.item()

# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
# Reference: https://github.com/pytorch/pytorch/issues/112577.
enable_gqa = q_per_kv > 1
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
query=query,
key=key,
value=value,
attn_mask=attention_mask,
dropout_p=dropout,
scale=scaling,
is_causal=is_causal,
enable_gqa=enable_gqa,
)
attn_output = attn_output.transpose(1, 2).contiguous()

# attn_output: (batch, q_len, num_heads, head_dim)
return attn_output, None
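
For reference, a hedged sketch of what the docstring above points at: on a PyTorch build whose `scaled_dot_product_attention` accepts `enable_gqa` (older releases reject the keyword), passing the un-repeated key/value heads should match the old `repeat_kv` baseline. Sizes are made up; this snippet is illustrative and not part of the diff.

```python
import torch
import torch.nn.functional as F

bs, num_heads, num_kv_groups, q_len, kv_len, head_dim = 2, 8, 2, 5, 7, 64
q = torch.randn(bs, num_heads, q_len, head_dim)
k = torch.randn(bs, num_kv_groups, kv_len, head_dim)
v = torch.randn(bs, num_kv_groups, kv_len, head_dim)

# GQA path: keys/values keep their num_key_value_groups heads.
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Baseline path: materialize the repeated keys/values, as `repeat_kv` used to do.
n_rep = num_heads // num_kv_groups
out_baseline = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
)

torch.testing.assert_close(out_gqa, out_baseline)
```

The point of the `enable_gqa` path is that it skips the explicit head repetition, so the kernel never has to allocate the expanded key/value tensors.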
143 changes: 100 additions & 43 deletions src/transformers/models/aria/modeling_aria.py
@@ -40,6 +40,7 @@

if is_torch_available():
import torch
import torch.nn.functional as F
from torch import nn


@@ -341,7 +342,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
"""
fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
projection, gate = torch.chunk(fc1_output, 2, dim=-1)
fc1_output = nn.functional.silu(projection) * gate
fc1_output = F.silu(projection) * gate
fc2_output = self.fc2(fc1_output, tokens_per_expert)
return fc2_output

@@ -390,7 +391,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Top K Routing
logits = self.router(hidden_states)
top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
scores = nn.functional.softmax(top_logits, dim=-1)
scores = F.softmax(top_logits, dim=-1)

original_dtype = top_indices.dtype

@@ -426,23 +427,117 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return output + shared_expert_output


def _attention_compute_scores(
query: torch.Tensor,
key: torch.Tensor,
) -> torch.Tensor:
nh_q = query.shape[1]
nh_k = key.shape[1]
# - query: (bs, nh_q, T_q, hs)
# - key: (bs, nh_k, T_k, hs)
q_per_kv = nh_q // nh_k
key_transposed = key.mT # (bs, nh_k, hs, T_k)
if q_per_kv == 1:
return query @ key_transposed
else:
assert q_per_kv > 1
if nh_k > 1:
q_shape = query.shape[:1] + (nh_k, q_per_kv) + query.shape[2:]
_query = query.view(*q_shape)
key_transposed = key_transposed.unsqueeze(2)
else:
_query = query
# At this point:
# - _query: (bs, nh_k, q_per_kv, T_q, hs)
# - key_transposed: (bs, nh_k, 1, hs, T_k)
# - scores: (bs, nh_k, q_per_kv, T_q, T_k) -> (bs, nh_q, T_q, T_k)
scores = torch.matmul(_query, key_transposed)
s_shape = query.shape[:-1] + (key.shape[2],)
return scores.view(*s_shape)


def _attention_compute_weighted_values(
scores: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
nh_q = scores.shape[1]
nh_k = value.shape[1]
# - scores: (bs, nh_q, T_q, T_k)
# - value: (bs, nh_k, T_k, hs)
q_per_kv = nh_q // nh_k
if q_per_kv == 1:
return scores @ value
else:
if nh_k > 1:
s_shape = scores.shape[:1] + (nh_k, q_per_kv) + scores.shape[2:]
_scores = scores.view(*s_shape)
_value = value.unsqueeze(2)
else:
_scores = scores
_value = value
# At this point:
# - _scores: (bs, nh_k, q_per_kv, T_q, T_k)
# - _value: (bs, nh_k, 1, T_k, hs)
# - result: (bs, nh_k, q_per_kv, T_q, hs) -> (bs, nh_q, T_q, hs)
result = torch.matmul(_scores, _value)
r_shape = scores.shape[:-1] + (value.shape[-1],)
return result.view(*r_shape)


def eager_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
`query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
`value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
`num_key_value_groups <= num_heads` and
`num_heads % num_key_value_groups == 0`.

"""
assert query.ndim == key.ndim == value.ndim == 4
_, num_heads, q_len, _ = query.shape
_, num_key_value_groups, kv_len, _ = key.shape
assert query.shape[0] == key.shape[0] == value.shape[0] # batch_size
assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups

attn_weights = _attention_compute_scores(query, key) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, :kv_len]
attn_weights = attn_weights + causal_mask

attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
attn_output = _attention_compute_weighted_values(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
# attn_output: (batch, q_len, num_heads, head_dim)
# attn_weights: (batch, num_heads, q_len, kv_len)

return attn_output, attn_weights


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -460,44 +555,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights


class AriaTextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

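
The grouped matmul in `_attention_compute_scores` avoids materializing repeated keys by folding the `q_per_kv` factor into an extra broadcast dimension. Below is a minimal standalone check of that reshape trick against the expand-based `repeat_kv` baseline this diff removes; all sizes are made up for illustration.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand-based repeat, as in the helper removed by this diff:
    # (bs, nh_k, T, hs) -> (bs, nh_k * n_rep, T, hs)
    bs, nh_k, seq_len, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(bs, nh_k, n_rep, seq_len, head_dim)
    return x.reshape(bs, nh_k * n_rep, seq_len, head_dim)

bs, nh_q, nh_k, t_q, t_k, hs = 2, 8, 2, 5, 7, 16
q_per_kv = nh_q // nh_k
query = torch.randn(bs, nh_q, t_q, hs)
key = torch.randn(bs, nh_k, t_k, hs)

# Baseline: repeat the keys, then a single batched matmul.
baseline = query @ repeat_kv(key, q_per_kv).mT                 # (bs, nh_q, t_q, t_k)

# Grouped variant: view the query heads as (nh_k, q_per_kv) and broadcast the keys,
# which is the core of `_attention_compute_scores` above.
grouped = torch.matmul(
    query.view(bs, nh_k, q_per_kv, t_q, hs),                   # (bs, nh_k, q_per_kv, t_q, hs)
    key.mT.unsqueeze(2),                                        # (bs, nh_k, 1, hs, t_k)
).view(bs, nh_q, t_q, t_k)

torch.testing.assert_close(baseline, grouped)
```

The same folding applies symmetrically in `_attention_compute_weighted_values`, where the scores are viewed as `(bs, nh_k, q_per_kv, T_q, T_k)` and the un-repeated values are broadcast against them.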
5 changes: 3 additions & 2 deletions src/transformers/models/aria/modular_aria.py
@@ -62,6 +62,7 @@

if is_torch_available():
import torch
import torch.nn.functional as F
from torch import nn


@@ -1165,7 +1166,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
"""
fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
projection, gate = torch.chunk(fc1_output, 2, dim=-1)
fc1_output = nn.functional.silu(projection) * gate
fc1_output = F.silu(projection) * gate
fc2_output = self.fc2(fc1_output, tokens_per_expert)
return fc2_output

@@ -1214,7 +1215,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Top K Routing
logits = self.router(hidden_states)
top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
scores = nn.functional.softmax(top_logits, dim=-1)
scores = F.softmax(top_logits, dim=-1)

original_dtype = top_indices.dtype
