
 import torch
 from torch import nn
+from wenet.transformer.embedding import apply_rotary_emb

 from wenet.utils.common import get_dtype_min

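The import added above brings in apply_rotary_emb from wenet/transformer/embedding.py. As a rough sketch of what a rotary position embedding (RoPE) application of this kind does (illustrative only: the helper names below are made up, and the real WeNet function may use a different pairing convention and broadcasting), each adjacent pair of feature dimensions is treated as one complex number and rotated by an angle that depends on the position and the dimension index:

```python
import torch


def rotary_emb_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate x by precomputed complex frequencies (illustrative sketch)."""
    # x: (batch, head, time, d_k); freqs_cis: complex tensor (time, d_k // 2).
    # View each adjacent pair of features as one complex number.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Multiplying by unit complex numbers rotates each pair by its angle.
    x_rotated = x_complex * freqs_cis
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)


def freqs_cis_sketch(d_k: int, max_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Angles pos / theta**(2i / d_k) as unit complex numbers, as in the RoPE paper."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, d_k, 2).float() / d_k))
    angles = torch.outer(torch.arange(max_len).float(), inv_freq)  # (max_len, d_k // 2)
    return torch.polar(torch.ones_like(angles), angles)


q = torch.randn(2, 4, 16, 64)  # (batch, head, time, d_k)
q_rot = rotary_emb_sketch(q, freqs_cis_sketch(64, 16))
print(q_rot.shape)  # torch.Size([2, 4, 16, 64])
```

Because the same rotation is applied to queries and keys, their dot product depends only on the relative distance between positions, which is what makes RoPE attractive for attention.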
@@ -424,3 +425,104 @@ def forward(
                query.size(0), -1,
                self.h * self.d_k))  # (batch, time1, d_model)
        return self.linear_out(output), new_cache
+
+
+class RopeMultiHeadedAttention(MultiHeadedAttention):
+
+    def __init__(self,
+                 n_head: int,
+                 n_feat: int,
+                 dropout_rate: float,
+                 key_bias: bool = True,
+                 use_sdpa: bool = False,
+                 bias: bool = True,
+                 n_kv_head: Optional[int] = None,
+                 head_dim: Optional[int] = None):
+        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa,
+                         bias, n_kv_head, head_dim)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        pos_emb: torch.Tensor = torch.empty(0),
+        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute scaled dot product attention with rotary position embedding.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+                1. When applying cross attention between decoder and encoder,
+                   the batch padding mask for the input is in (#batch, 1, T) shape.
+                2. When applying self attention of the encoder,
+                   the mask is in (#batch, T, T) shape.
+                3. When applying self attention of the decoder,
+                   the mask is in (#batch, L, L) shape.
+                4. If different positions in the decoder see different
+                   blocks of the encoder, such as Mocha, the passed-in
+                   mask could be in (#batch, L, T) shape, but there is
+                   no such case in current WeNet.
+            pos_emb (torch.Tensor): Precomputed rotary frequencies
+                (freqs_cis) consumed by apply_rotary_emb.
+            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        # NOTE(Mddct): In order to make the code easier to read,
+        #     these two lines are not placed in MultiHeadedAttention.
+        # Rotary embedding is applied before the cache is concatenated, so
+        # the (already rotated) cached keys are not rotated a second time.
+        q = apply_rotary_emb(q, freqs_cis=pos_emb)
+        k = apply_rotary_emb(k, freqs_cis=pos_emb)
+
+        # Same cache layout as MultiHeadedAttention.forward above.
+        if cache.size(0) > 0:
+            key_cache, value_cache = torch.split(cache,
+                                                 cache.size(-1) // 2,
+                                                 dim=-1)
+            k = torch.cat([key_cache, k], dim=2)
+            v = torch.cat([value_cache, v], dim=2)
+        new_cache = torch.cat((k, v), dim=-1)
+
+        if self.h_kv != self.h:
+            k = torch.repeat_interleave(
+                k,
+                self.h // self.h_kv,
+                dim=1,
+            )
+            v = torch.repeat_interleave(
+                v,
+                self.h // self.h_kv,
+                dim=1,
+            )
+
+        if not self.use_sdpa:
+            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+            return self.forward_attention(v, scores, mask), new_cache
+        else:
+            output = torch.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask.unsqueeze(1),
+                dropout_p=self.dropout_rate,
+                scale=1 / math.sqrt(self.d_k),
+            )
+            output = (output.transpose(1, 2).contiguous().view(
+                query.size(0), -1,
+                self.h * self.d_k))  # (batch, time1, d_model)
+            return self.linear_out(output), new_cache
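For reference, the returned new_cache packs keys and values along the last dimension, which is why the docstring gives the cache shape as (1, head, cache_t, d_k * 2). Below is a minimal, self-contained sketch of how that cache grows over streaming chunks; the helper name and tensor sizes are illustrative only, not WeNet API:

```python
import torch


def cache_step(k_new: torch.Tensor, v_new: torch.Tensor, cache: torch.Tensor):
    """One decoding chunk: prepend cached k/v, then repack them (sketch)."""
    # k_new, v_new: (batch, head, chunk, d_k); cache: (batch, head, cache_t, 2 * d_k).
    if cache.size(0) > 0:
        key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
        k_new = torch.cat([key_cache, k_new], dim=2)
        v_new = torch.cat([value_cache, v_new], dim=2)
    # Pack key and value side by side so one tensor can be carried between chunks.
    return k_new, v_new, torch.cat((k_new, v_new), dim=-1)


head, d_k, chunk = 4, 64, 8
cache = torch.zeros(0, 0, 0, 0)  # empty cache for the first chunk
for _ in range(2):
    k = torch.randn(1, head, chunk, d_k)
    v = torch.randn(1, head, chunk, d_k)
    k, v, cache = cache_step(k, v, cache)
print(cache.shape)  # torch.Size([1, 4, 16, 128]) -> (1, head, cache_t + chunk, d_k * 2)
```

In the RoPE variant above, the keys stored in this cache are already rotated, so only the newly arriving chunk needs apply_rotary_emb.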