diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index ba31edffc..d43c52164 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -90,6 +90,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -131,7 +132,8 @@ def __init__(
                           activation,
                           mlp_bias,
                           n_expert=n_expert,
-                          n_expert_activated=n_expert_activated),
+                          n_expert_activated=n_expert_activated,
+                          gate_type=gate_type),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type,
@@ -360,6 +362,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal'
     ):

         super().__init__()
@@ -393,7 +396,8 @@ def __init__(
             mlp_type=mlp_type,
             mlp_bias=mlp_bias,
             n_expert=n_expert,
-            n_expert_activated=n_expert_activated)
+            n_expert_activated=n_expert_activated,
+            gate_type=gate_type)

         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -423,7 +427,8 @@ def __init__(
             mlp_type=mlp_type,
             mlp_bias=mlp_bias,
             n_expert=n_expert,
-            n_expert_activated=n_expert_activated)
+            n_expert_activated=n_expert_activated,
+            gate_type=gate_type)

     def forward(
         self,
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 9cfd260ea..aaa9612ad 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -394,6 +394,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         """ Construct TransformerEncoder

@@ -423,7 +424,8 @@ def __init__(
                           activation,
                           mlp_bias,
                           n_expert=n_expert,
-                          n_expert_activated=n_expert_activated),
+                          n_expert_activated=n_expert_activated,
+                          gate_type=gate_type),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type=layer_norm_type,
@@ -474,6 +476,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         """Construct ConformerEncoder

@@ -522,6 +525,7 @@ def __init__(
             mlp_bias,
             n_expert,
             n_expert_activated,
+            gate_type,
         )
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index e4c38e0f9..e7be32f64 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -15,7 +15,7 @@
 """Positionwise feed forward layer definition."""

 import torch
-
+import torch.nn.functional as F

 class PositionwiseFeedForward(torch.nn.Module):
     """Positionwise feed forward layer.
@@ -66,6 +66,8 @@ class MoEFFNLayer(torch.nn.Module):

     Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
                   https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
+
+    Noisy-gate reference from https://arxiv.org/pdf/1701.06538.pdf
     Args:
         n_expert: number of expert.
         n_expert_activated: The actual number of experts used for each frame
@@ -84,6 +86,7 @@ def __init__(
         bias: bool = False,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         super(MoEFFNLayer, self).__init__()
         self.gate = torch.nn.Linear(idim, n_expert, bias=False)
@@ -93,6 +96,9 @@ def __init__(
             for _ in range(n_expert))
         self.n_expert = n_expert
         self.n_expert_activated = n_expert_activated
+        self.gate_type = gate_type
+        if self.gate_type == 'noisy':
+            self.noisy_gate = torch.nn.Linear(idim, n_expert, bias=False)

     def forward(self, xs: torch.Tensor) -> torch.Tensor:
         """Foward function.
@@ -106,6 +112,12 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         ) # batch size, sequence length, embedding dimension (idim)
         xs = xs.view(-1, D)  # (B*L, D)
         router = self.gate(xs)  # (B*L, n_expert)
+        if self.gate_type == 'noisy':
+            noisy_router = self.noisy_gate(xs)
+            noisy_router = (
+                torch.randn_like(router) * F.softplus(noisy_router)
+            ) * self.training
+            router = router + noisy_router
         logits, selected_experts = torch.topk(
             router, self.n_expert_activated
         )  # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)
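For context, the gate_type='noisy' path added above implements the noisy top-k gating of Shazeer et al. (https://arxiv.org/pdf/1701.06538.pdf): a second linear projection produces a per-expert noise scale, Gaussian noise scaled by softplus of that projection is added to the router logits during training only, and the top-k expert selection then proceeds as before. Below is a minimal, self-contained sketch of just this gating path, not the wenet module itself; the class name is illustrative, and the softmax over the selected logits follows the Mistral reference cited in the docstring rather than code shown in this diff.

import torch
import torch.nn.functional as F


class NoisyTopKGateSketch(torch.nn.Module):
    """Illustrative noisy top-k router (sketch, not the patched MoEFFNLayer)."""

    def __init__(self,
                 idim: int,
                 n_expert: int = 8,
                 n_expert_activated: int = 2,
                 gate_type: str = 'normal'):
        super().__init__()
        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
        self.n_expert_activated = n_expert_activated
        self.gate_type = gate_type
        if self.gate_type == 'noisy':
            # Extra projection that controls the per-expert noise scale.
            self.noisy_gate = torch.nn.Linear(idim, n_expert, bias=False)

    def forward(self, xs: torch.Tensor):
        # xs: (B*L, idim), already flattened as in the patched forward().
        router = self.gate(xs)  # (B*L, n_expert)
        if self.gate_type == 'noisy' and self.training:
            # Gaussian noise scaled by softplus(noisy logits), training only.
            noise = torch.randn_like(router) * F.softplus(self.noisy_gate(xs))
            router = router + noise
        logits, selected_experts = torch.topk(router, self.n_expert_activated)
        weights = F.softmax(logits, dim=-1)  # (B*L, n_expert_activated)
        return weights, selected_experts


# Usage: routing weights and expert indices for 4 frames of dimension 16.
gate = NoisyTopKGateSketch(idim=16, gate_type='noisy').train()
w, idx = gate(torch.randn(4, 16))

Note that the patch multiplies the noise by self.training (a bool broadcast to 0/1) instead of branching; the sketch uses an explicit if self.training, which likewise disables the noise at eval time.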