support multi query attention
Mddct committed Feb 27, 2024
1 parent abc1c48 commit 9046622
Showing 3 changed files with 107 additions and 48 deletions.
119 changes: 78 additions & 41 deletions wenet/transformer/attention.py
@@ -16,7 +16,7 @@
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch import nn
@@ -26,6 +26,8 @@

class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
If n_kv_head is not None and n_kv_head != n_head, multi-query
attention (https://arxiv.org/pdf/1911.02150.pdf) is used instead.
Args:
n_head (int): The number of heads.
@@ -34,23 +36,40 @@ class MultiHeadedAttention(nn.Module):
"""

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
use_sdpa: bool = False,
bias: bool = True):
def __init__(
self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
use_sdpa: bool = False,
bias: bool = True,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0

self.inner_dim = n_feat if head_dim is None else head_dim * n_head
if n_kv_head is not None:
assert head_dim is not None
self.inner_kv_dim = head_dim * n_kv_head
n_kv_head = n_kv_head
else:
self.inner_kv_dim = self.inner_dim
n_kv_head = n_head
if self.inner_dim == n_feat:
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.d_k = self.inner_dim // n_head
assert self.d_k == self.inner_kv_dim // n_kv_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat, bias=bias)
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
self.linear_v = nn.Linear(n_feat, n_feat, bias=bias)
self.linear_out = nn.Linear(n_feat, n_feat, bias=bias)
self.h_kv = n_head if n_kv_head is None else n_kv_head

self.linear_q = nn.Linear(n_feat, self.inner_dim, bias=bias)
self.linear_k = nn.Linear(n_feat, self.inner_kv_dim, bias=key_bias)
self.linear_v = nn.Linear(n_feat, self.inner_kv_dim, bias=bias)
self.linear_out = nn.Linear(self.inner_dim, n_feat, bias=bias)
self.dropout = nn.Dropout(p=dropout_rate)

self.use_sdpa = use_sdpa
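
A minimal usage sketch of the constructor above (the n_head=8 / n_kv_head=1 / head_dim=64 values are made up): with a single shared key/value head, only the query projection keeps the full width, while linear_k / linear_v shrink to n_kv_head * head_dim.

```python
from wenet.transformer.attention import MultiHeadedAttention

# Hypothetical MQA configuration: 8 query heads share a single key/value head.
mha = MultiHeadedAttention(n_head=8,
                           n_feat=512,
                           dropout_rate=0.0,
                           n_kv_head=1,
                           head_dim=64)

print(mha.linear_q.weight.shape)  # torch.Size([512, 512]) -> (inner_dim, n_feat)
print(mha.linear_k.weight.shape)  # torch.Size([64, 512])  -> (inner_kv_dim, n_feat)
print(mha.linear_v.weight.shape)  # torch.Size([64, 512])
```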
@@ -70,18 +89,18 @@ def forward_qkv(
torch.Tensor: Transformed query tensor, size
(#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, n_head, time2, d_k).
(#batch, n_kv_head, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
(#batch, n_kv_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h_kv, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h_kv, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
k = k.transpose(1, 2) # (batch, head_kv, time2, d_k)
v = v.transpose(1, 2) # (batch, head_kv, time2, d_k)

return q, k, v
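
A shape check for forward_qkv under the same hypothetical multi-query setup: q is split into n_head heads, while k and v carry only n_kv_head heads.

```python
import torch
from wenet.transformer.attention import MultiHeadedAttention

# Same hypothetical setup: n_head=8, n_kv_head=1, head_dim=64, n_feat=512.
mha = MultiHeadedAttention(8, 512, 0.0, n_kv_head=1, head_dim=64)

q, k, v = mha.forward_qkv(torch.randn(2, 100, 512),  # query (batch, time1, n_feat)
                          torch.randn(2, 80, 512),   # key   (batch, time2, n_feat)
                          torch.randn(2, 80, 512))   # value (batch, time2, n_feat)
print(q.shape)  # torch.Size([2, 8, 100, 64]) -> (batch, n_head, time1, d_k)
print(k.shape)  # torch.Size([2, 1, 80, 64])  -> (batch, n_kv_head, time2, d_k)
print(v.shape)  # torch.Size([2, 1, 80, 64])
```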

@@ -198,6 +217,17 @@ def forward(
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=1,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=1,
)

if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
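
The repeat_interleave above expands the shared key/value heads back to n_head right before the score computation. A standalone illustration with invented sizes:

```python
import torch

h, h_kv, d_k = 8, 2, 64            # hypothetical grouped-query sizes
q = torch.randn(2, h, 100, d_k)    # (batch, n_head, time1, d_k)
k = torch.randn(2, h_kv, 80, d_k)  # (batch, n_kv_head, time2, d_k)

# Each key/value head is repeated h // h_kv times to line up with the query heads.
k = torch.repeat_interleave(k, h // h_kv, dim=1)
print(k.shape)  # torch.Size([2, 8, 80, 64])

scores = torch.matmul(q, k.transpose(-2, -1)) / d_k**0.5
print(scores.shape)  # torch.Size([2, 8, 100, 80])
```

Note that new_cache is built before this expansion, so the streaming cache keeps the compact n_kv_head layout.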
@@ -226,22 +256,28 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
dropout_rate (float): Dropout rate.
"""

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
use_sdpa: bool = False,
bias: bool = True):
def __init__(
self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
use_sdpa: bool = False,
bias: bool = True,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head,
n_feat,
dropout_rate,
key_bias,
use_sdpa,
n_kv_head=n_kv_head,
key_bias=key_bias,
use_sdpa=use_sdpa,
head_dim=head_dim,
bias=bias)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
self.linear_pos = nn.Linear(n_feat, self.inner_dim, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
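
A quick sketch of the relative-position variant with a hypothetical grouped-query setup: linear_pos and the positional biases stay sized to the query side, only the key/value projections shrink.

```python
from wenet.transformer.attention import RelPositionMultiHeadedAttention

# Hypothetical grouped-query setup: 8 query heads, 2 shared key/value heads.
att = RelPositionMultiHeadedAttention(n_head=8,
                                      n_feat=512,
                                      dropout_rate=0.0,
                                      n_kv_head=2,
                                      head_dim=64)
print(att.linear_pos.weight.shape)  # torch.Size([512, 512]) -> (inner_dim, n_feat)
print(att.linear_k.weight.shape)    # torch.Size([128, 512]) -> (inner_kv_dim, n_feat)
print(att.pos_bias_u.shape)         # torch.Size([8, 64])    -> (n_head, d_k)
```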
@@ -327,6 +363,18 @@ def forward(
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
# Expand the shared kv heads only for the attention computation; the cache
# above keeps the compact n_kv_head layout, as in MultiHeadedAttention.forward.
if self.h_kv != self.h:
    k = torch.repeat_interleave(
        k,
        self.h // self.h_kv,
        dim=1,
    )
    v = torch.repeat_interleave(
        v,
        self.h // self.h_kv,
        dim=1,
    )
@@ -376,14 +424,3 @@ def forward(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache


class MultiQueryAttention(MultiHeadedAttention):

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
use_sdpa: bool = False):
super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
12 changes: 12 additions & 0 deletions wenet/transformer/decoder.py
@@ -80,6 +80,8 @@ def __init__(
bias: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
super().__init__()
attention_dim = encoder_output_size
@@ -114,6 +116,8 @@ def __init__(
key_bias,
use_sdpa,
bias=bias,
n_kv_head=n_kv_head,
head_dim=head_dim,
),
WENET_ATTENTION_CLASSES["selfattn"](
attention_heads,
Expand All @@ -122,6 +126,8 @@ def __init__(
key_bias,
use_sdpa,
bias=bias,
n_kv_head=n_kv_head,
head_dim=head_dim,
) if src_attention else None,
mlp_class(
attention_dim,
@@ -328,6 +334,8 @@ def __init__(
bias: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):

super().__init__()
@@ -353,6 +361,8 @@ def __init__(
bias=bias,
layer_norm_type=layer_norm_type,
eps=eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)

self.right_decoder = TransformerDecoder(
@@ -376,6 +386,8 @@ def __init__(
bias=bias,
layer_norm_type=layer_norm_type,
eps=eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)

def forward(
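
Since the decoder only threads the two new arguments through to its attention layers, enabling multi-query attention there is a constructor-level switch. A hedged sketch (all values are invented; remaining parameters keep their defaults):

```python
from wenet.transformer.decoder import BiTransformerDecoder

# Hypothetical: 4 attention heads sharing a single key/value head in every
# self- and cross-attention module of both decoder directions.
decoder = BiTransformerDecoder(vocab_size=5000,
                               encoder_output_size=256,
                               attention_heads=4,
                               n_kv_head=1,
                               head_dim=64)
```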
24 changes: 17 additions & 7 deletions wenet/transformer/encoder.py
@@ -14,7 +14,7 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint as ckpt
@@ -371,6 +371,8 @@ def __init__(
bias: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
""" Construct TransformerEncoder
@@ -388,12 +390,16 @@
self.encoders = torch.nn.ModuleList([
TransformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
output_size,
attention_dropout_rate,
key_bias,
use_sdpa,
bias=bias),
WENET_ATTENTION_CLASSES["selfattn"](
attention_heads,
output_size,
attention_dropout_rate,
key_bias=key_bias,
use_sdpa=use_sdpa,
bias=bias,
n_kv_head=n_kv_head,
head_dim=head_dim,
),
mlp_class(output_size,
linear_units,
dropout_rate,
@@ -442,6 +448,8 @@ def __init__(
bias: bool = True,
layer_norm_type: str = 'layer_norm',
eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
"""Construct ConformerEncoder
@@ -491,6 +499,8 @@
key_bias,
use_sdpa,
bias,
n_kv_head,
head_dim,
)
# feed-forward module definition
positionwise_layer_args = (
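
Putting it together at the encoder level, a rough sketch (feature sizes and lengths are invented; the keyword names follow this diff):

```python
import torch
from wenet.transformer.encoder import TransformerEncoder

# Hypothetical MQA encoder: every self-attention layer uses 4 query heads
# and one shared key/value head.
encoder = TransformerEncoder(input_size=80,
                             output_size=256,
                             attention_heads=4,
                             n_kv_head=1,
                             head_dim=64)

feats = torch.randn(2, 200, 80)        # (batch, frames, feature dim)
feats_lens = torch.tensor([200, 160])
out, out_mask = encoder(feats, feats_lens)
print(out.shape)  # e.g. torch.Size([2, 49, 256]) after conv2d subsampling
```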
