support linear bias
Mddct committed Feb 23, 2024
1 parent a17fc45 commit 427a033
Showing 4 changed files with 56 additions and 21 deletions.
30 changes: 24 additions & 6 deletions wenet/transformer/attention.py
@@ -39,17 +39,18 @@ def __init__(self,
                  n_feat: int,
                  dropout_rate: float,
                  key_bias: bool = True,
-                 use_sdpa: bool = False):
+                 use_sdpa: bool = False,
+                 bias: bool = True):
         """Construct an MultiHeadedAttention object."""
         super().__init__()
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_q = nn.Linear(n_feat, n_feat, bias=bias)
         self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
-        self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat, bias=bias)
+        self.linear_out = nn.Linear(n_feat, n_feat, bias=bias)
         self.dropout = nn.Dropout(p=dropout_rate)
 
         self.use_sdpa = use_sdpa
@@ -230,9 +231,15 @@ def __init__(self,
                  n_feat: int,
                  dropout_rate: float,
                  key_bias: bool = True,
-                 use_sdpa: bool = False):
+                 use_sdpa: bool = False,
+                 bias: bool = True):
         """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
+        super().__init__(n_head,
+                         n_feat,
+                         dropout_rate,
+                         key_bias,
+                         use_sdpa,
+                         bias=bias)
         # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         # these two learnable bias are used in matrix c and matrix d
@@ -369,3 +376,14 @@ def forward(
             query.size(0), -1,
             self.h * self.d_k))  # (batch, time1, d_model)
         return self.linear_out(output), new_cache
+
+
+class MultiQueryAttention(MultiHeadedAttention):
+
+    def __init__(self,
+                 n_head: int,
+                 n_feat: int,
+                 dropout_rate: float,
+                 key_bias: bool = True,
+                 use_sdpa: bool = False):
+        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
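The commit threads a new bias flag (default True) through the attention, decoder, encoder and position-wise feed-forward modules so that their nn.Linear projections can be built without bias terms. A hedged usage sketch for the attention change (the values are placeholders; only the bias keyword is new, and the key projection stays under key_bias):

    from wenet.transformer.attention import MultiHeadedAttention

    # Query, value and output projections are built without bias;
    # the key projection is still controlled separately by key_bias.
    attn = MultiHeadedAttention(n_head=4,
                                n_feat=256,
                                dropout_rate=0.1,
                                key_bias=True,
                                use_sdpa=False,
                                bias=False)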
12 changes: 9 additions & 3 deletions wenet/transformer/decoder.py
@@ -76,6 +76,7 @@ def __init__(
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        bias: bool = True,
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -92,7 +93,9 @@ def __init__(
         self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
         self.use_output_layer = use_output_layer
         if use_output_layer:
-            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+            self.output_layer = torch.nn.Linear(attention_dim,
+                                                vocab_size,
+                                                bias=bias)
         else:
             self.output_layer = torch.nn.Identity()
         self.num_blocks = num_blocks
@@ -301,6 +304,7 @@ def __init__(
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        bias: bool = True,
     ):
 
         super().__init__()
@@ -322,7 +326,8 @@ def __init__(
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa,
-            mlp_type=mlp_type)
+            mlp_type=mlp_type,
+            bias=bias)
 
         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -341,7 +346,8 @@ def __init__(
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa,
-            mlp_type=mlp_type)
+            mlp_type=mlp_type,
+            bias=bias)
 
     def forward(
         self,
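A hedged sketch of the decoder side (assuming vocab_size and encoder_output_size remain the leading constructor arguments and the remaining defaults are untouched; the values are placeholders):

    from wenet.transformer.decoder import TransformerDecoder

    # With bias=False the output projection becomes
    # torch.nn.Linear(attention_dim, vocab_size, bias=False).
    decoder = TransformerDecoder(vocab_size=5000,
                                 encoder_output_size=256,
                                 bias=False)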
17 changes: 12 additions & 5 deletions wenet/transformer/encoder.py
@@ -361,7 +361,8 @@ def __init__(self,
                  activation_type: str = "relu",
                  gradient_checkpointing: bool = False,
                  use_sdpa: bool = False,
-                 mlp_type: str = 'position_wise_feed_forward'):
+                 mlp_type: str = 'position_wise_feed_forward',
+                 bias: bool = True):
         """ Construct TransformerEncoder
 
         See Encoder for the meaning of each parameter.
@@ -381,10 +382,13 @@ def __init__(self,
                 WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
                                                     output_size,
                                                     attention_dropout_rate,
-                                                    key_bias, use_sdpa),
-                mlp_class(output_size, linear_units, dropout_rate,
-                          activation), dropout_rate, normalize_before)
-            for _ in range(num_blocks)
+                                                    key_bias,
+                                                    use_sdpa,
+                                                    bias=bias),
+                mlp_class(output_size, linear_units, dropout_rate, activation),
+                dropout_rate,
+                normalize_before,
+                bias=bias) for _ in range(num_blocks)
         ])
 
 
@@ -420,6 +424,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        bias: bool = True,
     ):
         """Construct ConformerEncoder
@@ -454,13 +459,15 @@ def __init__(
             attention_dropout_rate,
             key_bias,
             use_sdpa,
+            bias,
         )
         # feed-forward module definition
         positionwise_layer_args = (
             output_size,
             linear_units,
             dropout_rate,
             activation,
+            bias,
         )
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
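A minimal sketch for the encoder side (assuming input_size is still the leading constructor argument and the other defaults are left alone; values are placeholders):

    from wenet.transformer.encoder import ConformerEncoder

    # bias is appended to both the self-attention argument tuple and the
    # positionwise feed-forward argument tuple, so the linear layers in
    # every block are created without bias here.
    encoder = ConformerEncoder(input_size=80, bias=False)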
18 changes: 11 additions & 7 deletions wenet/transformer/positionwise_feed_forward.py
@@ -36,13 +36,14 @@ def __init__(
             hidden_units: int,
             dropout_rate: float,
             activation: torch.nn.Module = torch.nn.ReLU(),
+            bias: bool = True,
     ):
         """Construct a PositionwiseFeedForward object."""
         super(PositionwiseFeedForward, self).__init__()
-        self.w_1 = torch.nn.Linear(idim, hidden_units)
+        self.w_1 = torch.nn.Linear(idim, hidden_units, bias=bias)
         self.activation = activation
         self.dropout = torch.nn.Dropout(dropout_rate)
-        self.w_2 = torch.nn.Linear(hidden_units, idim)
+        self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias)
 
     def forward(self, xs: torch.Tensor) -> torch.Tensor:
         """Forward function.
@@ -80,12 +81,14 @@ def __init__(
             hidden_units: int,
             dropout_rate: float,
             activation: torch.nn.Module = torch.nn.ReLU(),
+            bias: bool = True,
     ):
         super(MoEFFNLayer, self).__init__()
         self.gate = torch.nn.Linear(idim, n_expert, bias=False)
         self.experts = torch.nn.ModuleList(
-            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
-                                    activation) for _ in range(n_expert))
+            PositionwiseFeedForward(
+                idim, hidden_units, dropout_rate, activation, bias=bias)
+            for _ in range(n_expert))
         self.n_expert_per_token = n_expert_per_token
 
     def forward(self, xs: torch.Tensor) -> torch.Tensor:
@@ -125,16 +128,17 @@ def __init__(
             hidden_units: int,
             dropout_rate: float,
             activation: torch.nn.Module = torch.nn.GELU(),
+            bias: bool = True,
     ):
         """Construct a PositionwiseFeedForward object."""
         super(GatedVariantsMLP, self).__init__()
-        self.gate = torch.nn.Linear(idim, hidden_units)
+        self.gate = torch.nn.Linear(idim, hidden_units, bias=False)
         self.activation = activation
         # w_1 as up proj
-        self.w_1 = torch.nn.Linear(idim, hidden_units)
+        self.w_1 = torch.nn.Linear(idim, hidden_units, bias=bias)
         self.dropout = torch.nn.Dropout(dropout_rate)
         # w_2 as down proj
-        self.w_2 = torch.nn.Linear(hidden_units, idim)
+        self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias)
 
     def forward(self, x):
         """Foward function.
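And a hedged sketch for the feed-forward module (values are placeholders; note that in GatedVariantsMLP the gate projection is now hard-wired to bias=False, while w_1 and w_2 follow the flag):

    import torch
    from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

    # Both w_1 and w_2 are built as bias-free Linear layers.
    ffn = PositionwiseFeedForward(idim=256,
                                  hidden_units=2048,
                                  dropout_rate=0.1,
                                  bias=False)
    out = ffn(torch.randn(2, 10, 256))  # (batch, time, idim)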
