From 906832b0fc21b25c268b05ea374ad1d542fa941d Mon Sep 17 00:00:00 2001
From: llleohk
Date: Thu, 18 Apr 2024 15:37:36 +0800
Subject: [PATCH 1/7] add noisy_gate
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 wenet/transformer/decoder.py                  | 11 ++++++++---
 wenet/transformer/encoder.py                  | 19 +++++++++++++------
 .../transformer/positionwise_feed_forward.py  | 10 +++++++++-
 3 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index f674cea284..401e4bd536 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -89,6 +89,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -130,7 +131,8 @@ def __init__(
                     activation,
                     mlp_bias,
                     n_expert=n_expert,
-                    n_expert_activated=n_expert_activated),
+                    n_expert_activated=n_expert_activated,
+                    gate_type=gate_type),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type,
@@ -352,6 +354,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal'
     ):

         super().__init__()
@@ -385,7 +388,8 @@ def __init__(
             mlp_type=mlp_type,
             mlp_bias=mlp_bias,
             n_expert=n_expert,
-            n_expert_activated=n_expert_activated)
+            n_expert_activated=n_expert_activated,
+            gate_type=gate_type)

         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -415,7 +419,8 @@ def __init__(
             mlp_type=mlp_type,
             mlp_bias=mlp_bias,
             n_expert=n_expert,
-            n_expert_activated=n_expert_activated)
+            n_expert_activated=n_expert_activated,
+            gate_type=gate_type)

     def forward(
         self,
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 26cb0ef8ca..e572ef70f4 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -184,15 +184,18 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs

-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     chunk_masks: torch.Tensor,
                                     pos_emb: torch.Tensor,
                                     mask_pad: torch.Tensor) -> torch.Tensor:
         for layer in self.encoders:
-            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
-                                                    chunk_masks, pos_emb,
-                                                    mask_pad, use_reentrant=False)
+            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__,
+                                                    xs,
+                                                    chunk_masks,
+                                                    pos_emb,
+                                                    mask_pad,
+                                                    use_reentrant=False)
         return xs

     def forward_chunk(
@@ -391,6 +394,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         """ Construct TransformerEncoder

@@ -420,7 +424,8 @@ def __init__(
                     activation,
                     mlp_bias,
                     n_expert=n_expert,
-                    n_expert_activated=n_expert_activated),
+                    n_expert_activated=n_expert_activated,
+                    gate_type=gate_type),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type=layer_norm_type,
@@ -471,6 +476,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         """Construct ConformerEncoder

@@ -519,6 +525,7 @@ def __init__(
             mlp_bias,
             n_expert,
             n_expert_activated,
+            gate_type,
         )
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
@@ -539,4 +546,4 @@ def __init__(
                 layer_norm_type=layer_norm_type,
                 norm_eps=norm_eps,
             ) for _ in range(num_blocks)
-        ])
+        ])
\ No newline at end of file
diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index e4c38e0f99..3ea2312d1c 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -15,7 +15,7 @@
 """Positionwise feed forward layer definition."""

 import torch
-
+import torch.nn.functional as F

 class PositionwiseFeedForward(torch.nn.Module):
     """Positionwise feed forward layer.
@@ -84,6 +84,7 @@ def __init__(
         bias: bool = False,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         super(MoEFFNLayer, self).__init__()
         self.gate = torch.nn.Linear(idim, n_expert, bias=False)
@@ -93,6 +94,9 @@ def __init__(
             for _ in range(n_expert))
         self.n_expert = n_expert
         self.n_expert_activated = n_expert_activated
+        self.gate_type= gate_type
+        if self.gate_type == 'noisy':
+            self.noisy_gate = torch.nn.Linear(idim, n_expert, bias=False)

     def forward(self, xs: torch.Tensor) -> torch.Tensor:
         """Foward function.
@@ -106,6 +110,10 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         ) # batch size, sequence length, embedding dimension (idim)
         xs = xs.view(-1, D)  # (B*L, D)
         router = self.gate(xs)  # (B*L, n_expert)
+        if self.gate_type == 'noisy':
+            noisy_router = self.noisy_gate(xs)
+            noisy_router = torch.randn_like(router)*F.softplus(noisy_router)
+            router = router + noisy_router
         logits, selected_experts = torch.topk(
             router, self.n_expert_activated
         )  # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)

From bc6fe955666ca12da6f19cdd801564762b626302 Mon Sep 17 00:00:00 2001
From: llleohk
Date: Thu, 18 Apr 2024 16:06:25 +0800
Subject: [PATCH 2/7] add noisy_gate

---
 wenet/transformer/encoder.py                   | 3 ++-
 wenet/transformer/positionwise_feed_forward.py | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index e572ef70f4..20874de57b 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -546,4 +546,5 @@ def __init__(
                 layer_norm_type=layer_norm_type,
                 norm_eps=norm_eps,
             ) for _ in range(num_blocks)
-        ])
\ No newline at end of file
+        ])
+
\ No newline at end of file
diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index 3ea2312d1c..dac1530c4a 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -94,7 +94,7 @@ def __init__(
             for _ in range(n_expert))
         self.n_expert = n_expert
         self.n_expert_activated = n_expert_activated
-        self.gate_type= gate_type
+        self.gate_type = gate_type
         if self.gate_type == 'noisy':
             self.noisy_gate = torch.nn.Linear(idim, n_expert, bias=False)

@@ -112,7 +112,7 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         router = self.gate(xs)  # (B*L, n_expert)
         if self.gate_type == 'noisy':
             noisy_router = self.noisy_gate(xs)
-            noisy_router = torch.randn_like(router)*F.softplus(noisy_router)
+            noisy_router = torch.randn_like(router) * F.softplus(noisy_router)
             router = router + noisy_router
         logits, selected_experts = torch.topk(
             router, self.n_expert_activated

From 1056712ca6790472ea03adf815ac0515c136b416 Mon Sep 17 00:00:00 2001
From: llleohk
Date: Thu, 18 Apr 2024 16:09:31 +0800
Subject: [PATCH 3/7] add noisy_gate

---
 wenet/transformer/encoder.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 20874de57b..aaa9612adc 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -547,4 +547,3 @@ def __init__(
                 layer_norm_type=layer_norm_type,
                 norm_eps=norm_eps,
             ) for _ in range(num_blocks)
         ])
-
\ No newline at end of file

From 8a114bdad400d18c1d52ea79ca41810c94e1e3ef Mon Sep 17 00:00:00 2001
From: llleohk
Date: Thu, 18 Apr 2024 17:47:29 +0800
Subject: [PATCH 4/7] add noisy_gate

---
 wenet/transformer/decoder.py                   | 11 ++++++++---
 wenet/transformer/encoder.py                   | 17 ++++++++++++-----
 wenet/transformer/positionwise_feed_forward.py | 10 +++++++++-
 3 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index f674cea284..401e4bd536 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -89,6 +89,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -130,7 +131,8 @@ def __init__(
                     activation,
                     mlp_bias,
                     n_expert=n_expert,
-                    n_expert_activated=n_expert_activated),
+                    n_expert_activated=n_expert_activated,
+                    gate_type=gate_type),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type,
@@ -352,6 +354,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal'
     ):

         super().__init__()
@@ -385,7 +388,8 @@ def __init__(
             mlp_type=mlp_type,
             mlp_bias=mlp_bias,
             n_expert=n_expert,
-            n_expert_activated=n_expert_activated)
+            n_expert_activated=n_expert_activated,
+            gate_type=gate_type)

         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -415,7 +419,8 @@ def __init__(
             mlp_type=mlp_type,
             mlp_bias=mlp_bias,
             n_expert=n_expert,
-            n_expert_activated=n_expert_activated)
+            n_expert_activated=n_expert_activated,
+            gate_type=gate_type)

     def forward(
         self,
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 26cb0ef8ca..aaa9612adc 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -184,15 +184,18 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs

-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     chunk_masks: torch.Tensor,
                                     pos_emb: torch.Tensor,
                                     mask_pad: torch.Tensor) -> torch.Tensor:
         for layer in self.encoders:
-            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
-                                                    chunk_masks, pos_emb,
-                                                    mask_pad, use_reentrant=False)
+            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__,
+                                                    xs,
+                                                    chunk_masks,
+                                                    pos_emb,
+                                                    mask_pad,
+                                                    use_reentrant=False)
         return xs

     def forward_chunk(
@@ -391,6 +394,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         """ Construct TransformerEncoder

@@ -420,7 +424,8 @@ def __init__(
                     activation,
                     mlp_bias,
                     n_expert=n_expert,
-                    n_expert_activated=n_expert_activated),
+                    n_expert_activated=n_expert_activated,
+                    gate_type=gate_type),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type=layer_norm_type,
@@ -471,6 +476,7 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         """Construct ConformerEncoder

@@ -519,6 +525,7 @@ def __init__(
             mlp_bias,
             n_expert,
             n_expert_activated,
+            gate_type,
         )
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index e4c38e0f99..dac1530c4a 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -15,7 +15,7 @@
 """Positionwise feed forward layer definition."""

 import torch
-
+import torch.nn.functional as F

 class PositionwiseFeedForward(torch.nn.Module):
     """Positionwise feed forward layer.
@@ -84,6 +84,7 @@ def __init__(
         bias: bool = False,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        gate_type: str = 'normal',
     ):
         super(MoEFFNLayer, self).__init__()
         self.gate = torch.nn.Linear(idim, n_expert, bias=False)
@@ -93,6 +94,9 @@ def __init__(
             for _ in range(n_expert))
         self.n_expert = n_expert
         self.n_expert_activated = n_expert_activated
+        self.gate_type = gate_type
+        if self.gate_type == 'noisy':
+            self.noisy_gate = torch.nn.Linear(idim, n_expert, bias=False)

     def forward(self, xs: torch.Tensor) -> torch.Tensor:
         """Foward function.
@@ -106,6 +110,10 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         ) # batch size, sequence length, embedding dimension (idim)
         xs = xs.view(-1, D)  # (B*L, D)
         router = self.gate(xs)  # (B*L, n_expert)
+        if self.gate_type == 'noisy':
+            noisy_router = self.noisy_gate(xs)
+            noisy_router = torch.randn_like(router) * F.softplus(noisy_router)
+            router = router + noisy_router
         logits, selected_experts = torch.topk(
             router, self.n_expert_activated
         )  # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)

From dafe8c058fa80265de7a89c1938d98373cb0ccfc Mon Sep 17 00:00:00 2001
From: llleohk
Date: Thu, 2 May 2024 23:05:09 +0800
Subject: [PATCH 5/7] only training noisy

---
 wenet/transformer/positionwise_feed_forward.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index dac1530c4a..716f266145 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -66,6 +66,8 @@ class MoEFFNLayer(torch.nn.Module):
     Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
         https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
+
+    Noisy-gate reference from https://arxiv.org/pdf/1701.06538.pdf

     Args:
         n_expert: number of expert.
         n_expert_activated: The actual number of experts used for each frame
@@ -112,7 +114,7 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         router = self.gate(xs)  # (B*L, n_expert)
         if self.gate_type == 'noisy':
             noisy_router = self.noisy_gate(xs)
-            noisy_router = torch.randn_like(router) * F.softplus(noisy_router)
+            noisy_router = torch.randn_like(router) * F.softplus(noisy_router) * self.training
             router = router + noisy_router
         logits, selected_experts = torch.topk(
             router, self.n_expert_activated

From f5eeb5cfe90c5da9421bb1e7e0cc4956838a1539 Mon Sep 17 00:00:00 2001
From: llleohk
Date: Thu, 2 May 2024 23:08:50 +0800
Subject: [PATCH 6/7] fix length

---
 wenet/transformer/positionwise_feed_forward.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index 716f266145..9931e21ce3 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -114,7 +114,9 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         router = self.gate(xs)  # (B*L, n_expert)
         if self.gate_type == 'noisy':
             noisy_router = self.noisy_gate(xs)
-            noisy_router = torch.randn_like(router) * F.softplus(noisy_router) * self.training
+            noisy_router = (
+                torch.randn_like(router) * F.softplus(noisy_router)
+                ) * self.training
             router = router + noisy_router
         logits, selected_experts = torch.topk(
             router, self.n_expert_activated

From 954aa10cc024663e24a6e08b1479ca976d1e8952 Mon Sep 17 00:00:00 2001
From: llleohk
Date: Thu, 2 May 2024 23:13:18 +0800
Subject: [PATCH 7/7] fix length

---
 wenet/transformer/positionwise_feed_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index 9931e21ce3..e7be32f644 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -116,7 +116,7 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
             noisy_router = self.noisy_gate(xs)
             noisy_router = (
                 torch.randn_like(router) * F.softplus(noisy_router)
-                ) * self.training
+            ) * self.training
             router = router + noisy_router
         logits, selected_experts = torch.topk(
             router, self.n_expert_activated
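
For reviewers who want to try the routing logic outside of WeNet, below is a minimal, self-contained sketch of the noisy top-k gating this series adds (following Shazeer et al., 2017, the reference cited in PATCH 5/7). The class name, the plain two-layer experts, and the per-expert dispatch loop are illustrative assumptions made so the example runs on its own; this is not WeNet's actual MoEFFNLayer.

import torch
import torch.nn.functional as F


class NoisyTopKMoE(torch.nn.Module):
    """Illustrative noisy top-k mixture-of-experts layer (not WeNet's MoEFFNLayer)."""

    def __init__(self, idim: int, hidden: int, n_expert: int = 8,
                 n_expert_activated: int = 2, gate_type: str = 'noisy'):
        super().__init__()
        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
        self.experts = torch.nn.ModuleList(
            torch.nn.Sequential(torch.nn.Linear(idim, hidden), torch.nn.ReLU(),
                                torch.nn.Linear(hidden, idim))
            for _ in range(n_expert))
        self.n_expert_activated = n_expert_activated
        self.gate_type = gate_type
        if gate_type == 'noisy':
            # Second projection that predicts a per-expert noise scale.
            self.noisy_gate = torch.nn.Linear(idim, n_expert, bias=False)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        B, L, D = xs.size()
        xs = xs.view(-1, D)                      # (B*L, D)
        router = self.gate(xs)                   # (B*L, n_expert)
        if self.gate_type == 'noisy' and self.training:
            # Input-dependent Gaussian noise, applied only during training.
            noise = torch.randn_like(router) * F.softplus(self.noisy_gate(xs))
            router = router + noise
        logits, selected = torch.topk(router, self.n_expert_activated, dim=-1)
        weights = torch.softmax(logits, dim=-1)  # (B*L, k) weights over chosen experts
        out = torch.zeros_like(xs)
        for k in range(self.n_expert_activated):
            idx = selected[:, k]                 # expert id chosen at rank k per token
            for e, expert in enumerate(self.experts):
                mask = idx == e
                if mask.any():
                    out[mask] += weights[mask, k:k + 1] * expert(xs[mask])
        return out.view(B, L, D)

Guarding the noise on self.training here makes explicit what the patches achieve by multiplying the noise term by self.training, which evaluates to 1 in train mode and 0 in eval mode, so inference always uses the clean router logits.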