
if is_torch_available():
    import torch
+    import torch.nn.functional as F
    from torch import nn


@@ -355,7 +356,7 @@ def forward(self, permuted_tokens, tokens_per_expert):
        """
        fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
        projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
+        fc1_output = F.silu(projection) * gate
        fc2_output = self.fc2(fc1_output, tokens_per_expert)
        return fc2_output

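For context on the gated activation in this hunk: `fc1` produces a projection and a gate in a single matmul, and the SwiGLU-style product `silu(projection) * gate` halves the feature dimension. A minimal standalone sketch with made-up sizes (not Aria's real dimensions):

```python
import torch
import torch.nn.functional as F

# Made-up sizes: 8 permuted tokens, fc1 producing 2 * 32 = 64 features per token.
fc1_output = torch.randn(8, 64)
projection, gate = torch.chunk(fc1_output, 2, dim=-1)  # two (8, 32) halves
gated = F.silu(projection) * gate                      # SwiGLU-style gating -> (8, 32)
print(gated.shape)  # torch.Size([8, 32])
```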
@@ -404,7 +405,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Top K Routing
        logits = self.router(hidden_states)
        top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = nn.functional.softmax(top_logits, dim=-1)
+        scores = F.softmax(top_logits, dim=-1)

        original_dtype = top_indices.dtype

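As a quick illustration of the routing step above (all sizes are toy values, not Aria's config): `topk` selects `moe_topk` experts per token, and the softmax is taken only over the selected logits, so each token's expert weights sum to one:

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts, moe_topk = 4, 8, 2        # assumed toy values
logits = torch.randn(num_tokens, num_experts)      # router logits, one row per token
top_logits, top_indices = torch.topk(logits, k=moe_topk, dim=1)
scores = F.softmax(top_logits, dim=-1)             # normalized over the selected experts only
print(top_indices.shape, scores.sum(dim=-1))       # torch.Size([4, 2]), a vector of ones
```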
@@ -440,23 +441,117 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return output + shared_expert_output


+def _attention_compute_scores(
+    query: torch.Tensor,
+    key: torch.Tensor,
+) -> torch.Tensor:
+    nh_q = query.shape[1]
+    nh_k = key.shape[1]
+    # - query: (bs, nh_q, T_q, hs)
+    # - key: (bs, nh_k, T_k, hs)
+    q_per_kv = nh_q // nh_k
+    key_transposed = key.mT  # (bs, nh_k, hs, T_k)
+    if q_per_kv == 1:
+        return query @ key_transposed
+    else:
+        assert q_per_kv > 1
+        if nh_k > 1:
+            q_shape = query.shape[:1] + (nh_k, q_per_kv) + query.shape[2:]
+            _query = query.view(*q_shape)
+            key_transposed = key_transposed.unsqueeze(2)
+        else:
+            _query = query
+        # At this point:
+        # - _query: (bs, nh_k, q_per_kv, T_q, hs)
+        # - key_transposed: (bs, nh_k, 1, hs, T_k)
+        # - scores: (bs, nh_k, q_per_kv, T_q, T_k) -> (bs, nh_q, T_q, T_k)
+        scores = torch.matmul(_query, key_transposed)
+        s_shape = query.shape[:-1] + (key.shape[2],)
+        return scores.view(*s_shape)
+
+
+def _attention_compute_weighted_values(
+    scores: torch.Tensor,
+    value: torch.Tensor,
+) -> torch.Tensor:
+    nh_q = scores.shape[1]
+    nh_k = value.shape[1]
+    # - scores: (bs, nh_q, T_q, T_k)
+    # - value: (bs, nh_k, T_k, hs)
+    q_per_kv = nh_q // nh_k
+    if q_per_kv == 1:
+        return scores @ value
+    else:
+        if nh_k > 1:
+            s_shape = scores.shape[:1] + (nh_k, q_per_kv) + scores.shape[2:]
+            _scores = scores.view(*s_shape)
+            _value = value.unsqueeze(2)
+        else:
+            _scores = scores
+            _value = value
+        # At this point:
+        # - _scores: (bs, nh_k, q_per_kv, T_q, T_k)
+        # - _value: (bs, nh_k, 1, T_k, hs)
+        # - result: (bs, nh_k, q_per_kv, T_q, hs) -> (bs, nh_q, T_q, hs)
+        result = torch.matmul(_scores, _value)
+        r_shape = scores.shape[:-1] + (value.shape[-1],)
+        return result.view(*r_shape)
+
+
+def eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    `query` has shape `(batch, num_heads, q_len, head_dim)`, while `key`,
+    `value` have shape `(batch, num_key_value_groups, kv_len, head_dim)`. Here,
+    `num_key_value_groups <= num_heads` and
+    `num_heads % num_key_value_groups == 0`.
+
+    """
+    assert query.ndim == key.ndim == value.ndim == 4
+    _, num_heads, q_len, _ = query.shape
+    _, num_key_value_groups, kv_len, _ = key.shape
+    assert query.shape[0] == key.shape[0] == value.shape[0]  # batch_size
+    assert value.shape[1] == num_key_value_groups and value.shape[2] == kv_len
+    assert num_heads % num_key_value_groups == 0 and num_heads >= num_key_value_groups
+
+    attn_weights = _attention_compute_scores(query, key) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, :kv_len]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = _attention_compute_weighted_values(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    # attn_output: (batch, q_len, num_heads, head_dim)
+    # attn_weights: (batch, num_heads, q_len, kv_len)
+
+    return attn_output, attn_weights
+
+
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
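To see what the new helpers do: in grouped-query attention, `query` carries all heads while `key`/`value` carry only `num_key_value_groups` heads, and `_attention_compute_scores` views the query heads as per-group blocks so a single broadcasted matmul against the un-repeated keys yields the usual `(bs, nh_q, T_q, T_k)` scores. A small shape/equivalence check, assuming the helpers from the hunk above are defined in the session and using arbitrary toy dimensions:

```python
import torch

bs, nh_q, nh_k, T_q, T_k, hs = 2, 8, 2, 5, 7, 16        # toy dimensions, not Aria's config
query = torch.randn(bs, nh_q, T_q, hs)
key = torch.randn(bs, nh_k, T_k, hs)

scores = _attention_compute_scores(query, key)
print(scores.shape)                                      # torch.Size([2, 8, 5, 7])

# Reference: repeat each KV head nh_q // nh_k times (what repeat_kv used to do), then matmul.
reference = query @ key.repeat_interleave(nh_q // nh_k, dim=1).mT
print(torch.allclose(scores, reference, atol=1e-5))      # True
```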
@@ -474,44 +569,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    return q_embed, k_embed


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
class AriaTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

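The deleted path first expanded key/value to all `num_heads` via `repeat_kv` (the `.expand(...).reshape(...)` forces a copy), so its temporaries grow by a factor of `num_heads // num_key_value_groups` compared with what the KV cache actually stores. A rough back-of-the-envelope sketch with assumed, non-Aria-specific numbers:

```python
# Assumed numbers (not Aria's actual config): 64 query heads, 8 KV heads, head_dim 128,
# batch 1, 8192-token context, bf16 (2 bytes per element).
bs, nh_q, nh_k, T_k, hs, bytes_per_el = 1, 64, 8, 8192, 128, 2
expanded = bs * nh_q * T_k * hs * bytes_per_el   # repeat_kv output in the old path, per K and per V
cached = bs * nh_k * T_k * hs * bytes_per_el     # the cached tensor the new helpers consume directly
print(expanded // 2**20, cached // 2**20)        # 128 16  (MiB)
```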
@@ -946,6 +1003,7 @@ def forward(
                    use_cache,
                    cache_position,
                    position_embeddings,
+                    **flash_attn_kwargs,
                )
            else:
                layer_outputs = decoder_layer(