Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[transformer] Add moe_noisy_gate #2495

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions wenet/transformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal',
):
super().__init__()
attention_dim = encoder_output_size
Expand Down Expand Up @@ -130,7 +131,8 @@ def __init__(
activation,
mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated),
n_expert_activated=n_expert_activated,
gate_type=gate_type),
dropout_rate,
normalize_before,
layer_norm_type,
Expand Down Expand Up @@ -352,6 +354,7 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal'
):

super().__init__()
Expand Down Expand Up @@ -385,7 +388,8 @@ def __init__(
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)
n_expert_activated=n_expert_activated,
gate_type=gate_type)

self.right_decoder = TransformerDecoder(
vocab_size,
Expand Down Expand Up @@ -415,7 +419,8 @@ def __init__(
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)
n_expert_activated=n_expert_activated,
gate_type=gate_type)

def forward(
self,
Expand Down
17 changes: 12 additions & 5 deletions wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,18 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
return xs

@torch.jit.ignore(drop=True)
@torch.jit.unused
def forward_layers_checkpointed(self, xs: torch.Tensor,
chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
for layer in self.encoders:
xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
chunk_masks, pos_emb,
mask_pad, use_reentrant=False)
xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__,
xs,
chunk_masks,
pos_emb,
mask_pad,
use_reentrant=False)
return xs

def forward_chunk(
Expand Down Expand Up @@ -391,6 +394,7 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal',
):
""" Construct TransformerEncoder

Expand Down Expand Up @@ -420,7 +424,8 @@ def __init__(
activation,
mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated),
n_expert_activated=n_expert_activated,
gate_type=gate_type),
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
Expand Down Expand Up @@ -471,6 +476,7 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal',
):
"""Construct ConformerEncoder

Expand Down Expand Up @@ -519,6 +525,7 @@ def __init__(
mlp_bias,
n_expert,
n_expert_activated,
gate_type,
)
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
Expand Down
10 changes: 9 additions & 1 deletion wenet/transformer/positionwise_feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Positionwise feed forward layer definition."""

import torch

import torch.nn.functional as F

class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(
bias: bool = False,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal',
):
super(MoEFFNLayer, self).__init__()
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
Expand All @@ -93,6 +94,9 @@ def __init__(
for _ in range(n_expert))
self.n_expert = n_expert
self.n_expert_activated = n_expert_activated
self.gate_type = gate_type
if self.gate_type == 'noisy':
self.noisy_gate = torch.nn.Linear(idim, n_expert, bias=False)

def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Foward function.
Expand All @@ -106,6 +110,10 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
) # batch size, sequence length, embedding dimension (idim)
xs = xs.view(-1, D) # (B*L, D)
router = self.gate(xs) # (B*L, n_expert)
if self.gate_type == 'noisy':
noisy_router = self.noisy_gate(xs)
noisy_router = torch.randn_like(router) * F.softplus(noisy_router)
router = router + noisy_router
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

推理阶段也需要吗?我理解这个更像是服务于训练阶段避免有的专家没参与训练的

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

推理理论上是不用的,这个可以做个实验测试一下相差多少

logits, selected_experts = torch.topk(
router, self.n_expert_activated
) # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)
Expand Down
Loading