[transformer] add rms-norm #2396

Merged · 2 commits · Mar 8, 2024
8 changes: 5 additions & 3 deletions wenet/transformer/convolution.py
@@ -19,6 +19,8 @@
 import torch
 from torch import nn

+from wenet.utils.class_utils import WENET_NORM_CLASSES
+

 class ConvolutionModule(nn.Module):
 """ConvolutionModule in Conformer model."""
@@ -68,13 +70,13 @@ def __init__(self,
 bias=bias,
 )

-assert norm in ['batch_norm', 'layer_norm']
+assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
 if norm == "batch_norm":
 self.use_layer_norm = False
-self.norm = nn.BatchNorm1d(channels)
+self.norm = WENET_NORM_CLASSES['batch_norm'](channels)
 else:
 self.use_layer_norm = True
-self.norm = nn.LayerNorm(channels)
+self.norm = WENET_NORM_CLASSES[norm](channels)

 self.pointwise_conv2 = nn.Conv1d(
 channels,
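The net effect in ConvolutionModule is that the norm layer is now resolved by name, with 'rms_norm' joining 'batch_norm' and 'layer_norm'. A minimal sketch of that selection pattern (the channel count is an arbitrary example; the transpose handling in forward() is unchanged and not shown here):

from wenet.utils.class_utils import WENET_NORM_CLASSES

channels, norm = 256, 'rms_norm'
assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
# batch_norm operates on (B, C, T); layer_norm/rms_norm normalize the last
# dim, so the module keeps use_layer_norm to decide whether to transpose.
use_layer_norm = norm != 'batch_norm'
norm_layer = WENET_NORM_CLASSES[norm](channels)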
6 changes: 5 additions & 1 deletion wenet/transformer/decoder.py
@@ -25,6 +25,7 @@
 WENET_ATTENTION_CLASSES,
 WENET_ACTIVATION_CLASSES,
 WENET_MLP_CLASSES,
+WENET_NORM_CLASSES,
 )
 from wenet.utils.common import mask_to_bias
 from wenet.utils.mask import (subsequent_mask, make_pad_mask)
@@ -81,6 +82,7 @@ def __init__(
 tie_word_embedding: bool = False,
 use_sdpa: bool = False,
 mlp_type: str = 'position_wise_feed_forward',
+layer_norm_type: str = 'layer_norm',
 ):
 super().__init__()
 attention_dim = encoder_output_size
@@ -93,8 +95,10 @@ def __init__(
 positional_dropout_rate),
 )

+assert layer_norm_type in ['layer_norm', 'rms_norm']
 self.normalize_before = normalize_before
-self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
+self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim,
+eps=1e-5)
 self.use_output_layer = use_output_layer
 if use_output_layer:
 self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
10 changes: 7 additions & 3 deletions wenet/transformer/decoder_layer.py
@@ -18,6 +18,8 @@
 import torch
 from torch import nn

+from wenet.utils.class_utils import WENET_NORM_CLASSES
+

 class DecoderLayer(nn.Module):
 """Single decoder layer module.
@@ -46,16 +48,18 @@ def __init__(
 feed_forward: nn.Module,
 dropout_rate: float,
 normalize_before: bool = True,
+layer_norm_type: str = 'layer_norm',
 ):
 """Construct an DecoderLayer object."""
 super().__init__()
 self.size = size
 self.self_attn = self_attn
 self.src_attn = src_attn
 self.feed_forward = feed_forward
-self.norm1 = nn.LayerNorm(size, eps=1e-5)
-self.norm2 = nn.LayerNorm(size, eps=1e-5)
-self.norm3 = nn.LayerNorm(size, eps=1e-5)
+assert layer_norm_type in ['layer_norm', 'rms_norm']
+self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
 self.dropout = nn.Dropout(dropout_rate)
 self.normalize_before = normalize_before

29 changes: 18 additions & 11 deletions wenet/transformer/encoder.py
@@ -25,6 +25,7 @@
 from wenet.utils.class_utils import (
 WENET_EMB_CLASSES,
 WENET_MLP_CLASSES,
+WENET_NORM_CLASSES,
 WENET_SUBSAMPLE_CLASSES,
 WENET_ATTENTION_CLASSES,
 WENET_ACTIVATION_CLASSES,
@@ -55,6 +56,7 @@ def __init__(
 use_dynamic_left_chunk: bool = False,
 gradient_checkpointing: bool = False,
 use_sdpa: bool = False,
+layer_norm_type: str = 'layer_norm',
 ):
 """
 Args:
@@ -102,8 +104,10 @@
 positional_dropout_rate),
 )

+assert layer_norm_type in ['layer_norm', 'rms_norm']
 self.normalize_before = normalize_before
-self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
+eps=1e-5)
 self.static_chunk_size = static_chunk_size
 self.use_dynamic_chunk = use_dynamic_chunk
 self.use_dynamic_left_chunk = use_dynamic_left_chunk
@@ -368,6 +372,7 @@ def __init__(
 gradient_checkpointing: bool = False,
 use_sdpa: bool = False,
 mlp_type: str = 'position_wise_feed_forward',
+layer_norm_type: str = 'layer_norm',
 ):
 """ Construct TransformerEncoder

@@ -379,19 +384,21 @@
 input_layer, pos_enc_layer_type, normalize_before,
 static_chunk_size, use_dynamic_chunk, global_cmvn,
 use_dynamic_left_chunk, gradient_checkpointing,
-use_sdpa)
+use_sdpa, layer_norm_type)
 activation = WENET_ACTIVATION_CLASSES[activation_type]()
 mlp_class = WENET_MLP_CLASSES[mlp_type]
 self.encoders = torch.nn.ModuleList([
-TransformerEncoderLayer(
-output_size,
-WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
-output_size,
-attention_dropout_rate,
-query_bias, key_bias,
-value_bias, use_sdpa),
-mlp_class(output_size, linear_units, dropout_rate, activation,
-mlp_bias), dropout_rate, normalize_before)
+TransformerEncoderLayer(output_size,
+WENET_ATTENTION_CLASSES["selfattn"](
+attention_heads, output_size,
+attention_dropout_rate, query_bias,
+key_bias, value_bias, use_sdpa),
+mlp_class(output_size, linear_units,
+dropout_rate, activation,
+mlp_bias),
+dropout_rate,
+normalize_before,
+layer_norm_type=layer_norm_type)
 for _ in range(num_blocks)
 ])

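A hypothetical end-to-end construction showing the new layer_norm_type option being threaded down to every TransformerEncoderLayer; the 80-dim input and the (features, lengths) forward call are assumptions based on the usual wenet encoder interface, not part of this diff:

import torch
from wenet.transformer.encoder import TransformerEncoder

encoder = TransformerEncoder(
    input_size=80,               # e.g. 80-dim fbank features (assumption)
    layer_norm_type='rms_norm',  # default stays 'layer_norm'
)
xs = torch.randn(2, 100, 80)     # (batch, time, feature)
xs_lens = torch.tensor([100, 80])
out, masks = encoder(xs, xs_lens)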
24 changes: 17 additions & 7 deletions wenet/transformer/encoder_layer.py
@@ -20,6 +20,8 @@
 import torch
 from torch import nn

+from wenet.utils.class_utils import WENET_NORM_CLASSES
+

 class TransformerEncoderLayer(nn.Module):
 """Encoder layer module.
@@ -44,13 +46,15 @@ def __init__(
 feed_forward: torch.nn.Module,
 dropout_rate: float,
 normalize_before: bool = True,
+layer_norm_type: str = 'layer_norm',
 ):
 """Construct an EncoderLayer object."""
 super().__init__()
 self.self_attn = self_attn
 self.feed_forward = feed_forward
-self.norm1 = nn.LayerNorm(size, eps=1e-5)
-self.norm2 = nn.LayerNorm(size, eps=1e-5)
+assert layer_norm_type in ['layer_norm', 'rms_norm']
+self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
 self.dropout = nn.Dropout(dropout_rate)
 self.size = size
 self.normalize_before = normalize_before
@@ -135,23 +139,29 @@ def __init__(
 conv_module: Optional[nn.Module] = None,
 dropout_rate: float = 0.1,
 normalize_before: bool = True,
+layer_norm_type: str = 'layer_norm',
 ):
 """Construct an EncoderLayer object."""
 super().__init__()
 self.self_attn = self_attn
 self.feed_forward = feed_forward
+assert layer_norm_type in ['layer_norm', 'rms_norm']
 self.feed_forward_macaron = feed_forward_macaron
 self.conv_module = conv_module
-self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
-self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
+self.norm_ff = WENET_NORM_CLASSES[layer_norm_type](
+size, eps=1e-5)  # for the FNN module
+self.norm_mha = WENET_NORM_CLASSES[layer_norm_type](
+size, eps=1e-5)  # for the MHA module
 if feed_forward_macaron is not None:
-self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+self.norm_ff_macaron = WENET_NORM_CLASSES[layer_norm_type](
+size, eps=1e-5)
 self.ff_scale = 0.5
 else:
 self.ff_scale = 1.0
 if self.conv_module is not None:
-self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
-self.norm_final = nn.LayerNorm(
+self.norm_conv = WENET_NORM_CLASSES[layer_norm_type](
+size, eps=1e-5)  # for the CNN module
+self.norm_final = WENET_NORM_CLASSES[layer_norm_type](
 size, eps=1e-5)  # for the final output of the block
 self.dropout = nn.Dropout(dropout_rate)
 self.size = size
22 changes: 22 additions & 0 deletions wenet/transformer/norm.py
@@ -0,0 +1,22 @@
+import torch
+
+
+class RMSNorm(torch.nn.Module):
+    """ https://arxiv.org/pdf/1910.07467.pdf
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        x = self._norm(x.float()).type_as(x)
+        return x * self.weight
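A quick sanity check (not part of the diff): RMSNorm rescales each feature vector to unit root-mean-square and then applies the learned per-dimension weight; unlike torch.nn.LayerNorm it neither subtracts the mean nor adds a bias. A rough usage sketch:

import torch
from wenet.transformer.norm import RMSNorm

norm = RMSNorm(256)              # dim must match the feature size
x = torch.randn(4, 10, 256)      # (batch, time, feature)
y = norm(x)
# With the default all-ones weight, the RMS over the last dim is ~1.0.
print(y.pow(2).mean(-1).sqrt())  # values close to 1.0
print(y.shape)                   # torch.Size([4, 10, 256])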
8 changes: 8 additions & 0 deletions wenet/utils/class_utils.py
@@ -2,7 +2,9 @@
 # -*- coding: utf-8 -*-
 # Copyright [2023-11-28] <[email protected], Xingchen Song>
 import torch
+from torch.nn import BatchNorm1d, LayerNorm
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
+from wenet.transformer.norm import RMSNorm
 from wenet.transformer.positionwise_feed_forward import (
 GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)

@@ -77,3 +79,9 @@
 'moe': MoEFFNLayer,
 'gated': GatedVariantsMLP
 }
+
+WENET_NORM_CLASSES = {
+'layer_norm': LayerNorm,
+'batch_norm': BatchNorm1d,
+'rms_norm': RMSNorm
+}
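With the registry in place, callers resolve a norm implementation by name instead of hard-coding a class. A small sketch of the lookup the encoder/decoder layers now perform (the 256-dim size is an arbitrary example):

from wenet.utils.class_utils import WENET_NORM_CLASSES

norm_class = WENET_NORM_CLASSES['rms_norm']  # LayerNorm, BatchNorm1d or RMSNorm
norm = norm_class(256, eps=1e-5)             # the (size, eps) call the layers use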