fix init rope attention and rope
Mddct committed Mar 1, 2024
1 parent 0dc48f1 commit 522a60a
Showing 6 changed files with 73 additions and 35 deletions.
57 changes: 38 additions & 19 deletions wenet/transformer/decoder.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Decoder definition."""
- from typing import Tuple, List, Optional
+ from typing import Tuple, List, Optional, Union

import torch
import torch.utils.checkpoint as ckpt
@@ -82,17 +82,21 @@ def __init__(
eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
+ selfattention_layer_type: str = "selfattn",
):
+ assert selfattention_layer_type in ['selfattn', 'rope_selfattn']
super().__init__()
attention_dim = encoder_output_size
activation = WENET_ACTIVATION_CLASSES[activation_type]()

+ pos_emb_class = WENET_EMB_CLASSES[input_layer]
self.embed = torch.nn.Sequential(
torch.nn.Identity() if input_layer == "no_pos" else
torch.nn.Embedding(vocab_size, attention_dim),
- WENET_EMB_CLASSES[input_layer](attention_dim,
- positional_dropout_rate),
- )
+ pos_emb_class(attention_dim, positional_dropout_rate)
+ if input_layer != 'rope' else pos_emb_class(
+ attention_dim, attention_dim //
+ attention_heads, positional_dropout_rate))

self.normalize_before = normalize_before
self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim,
@@ -105,11 +109,12 @@ def __init__(
else:
self.output_layer = torch.nn.Identity()
self.num_blocks = num_blocks

mlp_class = WENET_MLP_CLASSES[mlp_type]
self.decoders = torch.nn.ModuleList([
DecoderLayer(
attention_dim,
WENET_ATTENTION_CLASSES["selfattn"](
WENET_ATTENTION_CLASSES[selfattention_layer_type](
attention_heads,
attention_dim,
self_attention_dropout_rate,
@@ -119,7 +124,7 @@ def __init__(
n_kv_head=n_kv_head,
head_dim=head_dim,
),
WENET_ATTENTION_CLASSES["selfattn"](
WENET_ATTENTION_CLASSES['selfattn'](
attention_heads,
attention_dim,
src_attention_dropout_rate,
@@ -191,35 +196,44 @@ def forward(
tgt_mask = mask_to_bias(tgt_mask, tgt.dtype)
memory_mask = mask_to_bias(memory_mask, memory_mask.dtype)

- x, _ = self.embed(tgt)
+ x, pos_emb = self.embed(tgt)
if self.gradient_checkpointing and self.training:
x = self.forward_layers_checkpointed(x, tgt_mask, memory,
- memory_mask)
+ memory_mask, pos_emb)
else:
- x = self.forward_layers(x, tgt_mask, memory, memory_mask)
+ x = self.forward_layers(x, tgt_mask, memory, memory_mask, pos_emb)
if self.normalize_before:
x = self.after_norm(x)
if self.use_output_layer:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, torch.tensor(0.0), olens

- def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- memory_mask: torch.Tensor) -> torch.Tensor:
+ def forward_layers(
+ self,
+ x: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ pos_emb: torch.Tensor = torch.empty(0),
+ ) -> torch.Tensor:
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
- memory_mask)
+ memory_mask, pos_emb)
return x

@torch.jit.ignore(drop=True)
- def forward_layers_checkpointed(self, x: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- memory_mask: torch.Tensor) -> torch.Tensor:
+ def forward_layers_checkpointed(
+ self,
+ x: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ pos_emb: torch.Tensor = torch.empty(0),
+ ) -> torch.Tensor:
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
- layer.__call__, x, tgt_mask, memory, memory_mask)
+ layer.__call__, x, tgt_mask, memory, memory_mask, pos_emb)
return x

def forward_one_step(
@@ -229,6 +243,7 @@ def forward_one_step(
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
+ offset: Union[int, torch.Tensor] = 0,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
This is only used for decoding.
@@ -244,7 +259,7 @@ def forward_one_step(
y, cache: NN output value and cache per `self.decoders`.
y.shape` is (batch, maxlen_out, token)
"""
- x, _ = self.embed(tgt)
+ x, pos_emb = self.embed(tgt, offset)
new_cache = []
for i, decoder in enumerate(self.decoders):
if cache is None:
@@ -255,6 +270,7 @@
tgt_mask,
memory,
memory_mask,
+ pos_emb,
cache=c)
new_cache.append(x)
if self.normalize_before:
@@ -336,6 +352,7 @@ def __init__(
eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
+ selfattention_layer_type: str = "selfattn",
):

super().__init__()
@@ -363,6 +380,7 @@ def __init__(
eps=eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
+ selfattention_layer_type=selfattention_layer_type,
)

self.right_decoder = TransformerDecoder(
@@ -388,6 +406,7 @@ def __init__(
eps=eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
+ selfattention_layer_type=selfattention_layer_type,
)

def forward(
3 changes: 2 additions & 1 deletion wenet/transformer/decoder_layer.py
@@ -69,6 +69,7 @@ def forward(
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
+ pos_emb: torch.Tensor = torch.empty(0),
cache: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.
@@ -110,7 +111,7 @@ def forward(
tgt_q_mask = tgt_mask[:, -1:, :]

x = residual + self.dropout(
- self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
+ self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0])
if not self.normalize_before:
x = self.norm1(x)

17 changes: 12 additions & 5 deletions wenet/transformer/embedding.py
@@ -209,29 +209,32 @@ def precompute_freqs_cis(dim: int,
return freqs_cis


- # copy from:https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95
+ # modified from:
+ # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
- torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
+ torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
- x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
- -1).transpose(1, 2)
+ x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1)
return x_out


class RopePositionalEncoding(PositionalEncoding):

def __init__(self,
d_model: int,
+ pos_dim: int,
dropout_rate: float,
max_len: int = 1500,
rope_theta=10000.0):
+ # NOTE(Mddct): pos_dim == attention_dim // attention_head
super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len)
delattr(self, 'pe')
- self.pe = precompute_freqs_cis(d_model, max_len * 2, rope_theta)
+ self.pe = precompute_freqs_cis(pos_dim, max_len * 2, rope_theta)
self.dropout_rate = dropout_rate
+ self.expand = False

def forward(
self,
@@ -240,7 +243,11 @@ def forward(
torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:

self.pe = self.pe.to(x.device)
+ if not self.expand:
+ self.pe = self.pe.unsqueeze(0)
+ self.expand = True
pos_emb = self.position_encoding(offset, x.size(1), False)
+ pos_emb = pos_emb.unsqueeze(1)  # [1, 1, seq, head_dim//2]
# NOTE(Mddct): some model don't scale
# TODO(Mddct): fix
x = x * self.xscale
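For reference (not part of this commit): a minimal, self-contained sketch of how the reshaped apply_rotary_emb is driven after this change. The old version transposed dims 1 and 2 before rotating and transposed them back; the new one rotates the input layout directly, and the comment added in forward suggests the factors are broadcast as [1, 1, seq, head_dim//2] against query/key tensors shaped [batch, head, time, head_dim]. The precompute_freqs_cis body below is paraphrased from the linked Gemma reference, and the tensor layout is an assumption drawn from those comments rather than from code shown in this diff.

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Complex rotation factors of shape [end, dim // 2]; body paraphrased from
    # the Gemma reference, wenet's own implementation is not shown in this hunk.
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:dim // 2].float() / dim))
    angles = torch.outer(torch.arange(end).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Same body as the new version above: view each half-pair of the last dim
    # as one complex number, rotate, and flatten back; no transpose.
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    return x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1)


batch, heads, time, head_dim = 2, 4, 16, 64
q = torch.randn(batch, heads, time, head_dim)   # [B, H, T, head_dim]
pe = precompute_freqs_cis(head_dim, 2 * time)   # [2*T, head_dim//2], complex
pos_emb = pe[:time].unsqueeze(0).unsqueeze(0)   # [1, 1, T, head_dim//2]
q_rot = apply_rotary_emb(q, pos_emb)            # [B, H, T, head_dim]
assert q_rot.shape == q.shape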
16 changes: 9 additions & 7 deletions wenet/transformer/encoder.py
@@ -95,13 +95,13 @@ def __init__(
self._output_size = output_size

self.global_cmvn = global_cmvn
+ pos_emb_class = WENET_EMB_CLASSES[pos_enc_layer_type]
self.embed = WENET_SUBSAMPLE_CLASSES[input_layer](
- input_size,
- output_size,
- dropout_rate,
- WENET_EMB_CLASSES[pos_enc_layer_type](output_size,
- positional_dropout_rate),
- )
+ input_size, output_size, dropout_rate,
+ pos_emb_class(output_size, positional_dropout_rate)
+ if pos_enc_layer_type != 'rope' else pos_emb_class(
+ output_size, output_size //
+ attention_heads, positional_dropout_rate))

self.normalize_before = normalize_before
assert layer_norm_type in ['layer_norm', 'rms_norm']
@@ -373,6 +373,7 @@ def __init__(
eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
+ selfattention_layer_type: str = "selfattn",
):
""" Construct TransformerEncoder
@@ -385,12 +386,13 @@
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa, layer_norm_type, eps)
+ assert selfattention_layer_type in ['selfattn', 'rope_selfattn']
activation = WENET_ACTIVATION_CLASSES[activation_type]()
mlp_class = WENET_MLP_CLASSES[mlp_type]
self.encoders = torch.nn.ModuleList([
TransformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES["selfattn"](
WENET_ATTENTION_CLASSES[selfattention_layer_type](
attention_heads,
output_size,
attention_dropout_rate,
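For context (not part of this commit): the new selfattention_layer_type argument and the 'rope' branch in the embedding construction suggest wiring like the sketch below. The registry keys 'rope' and 'rope_selfattn' are taken from the assertions in this diff; constructor defaults and import paths are assumptions based on the surrounding files, with other arguments left at their defaults.

from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import TransformerEncoder

# The rope table is sized per attention head (dim // heads), not per model dim,
# which is what the new embedding branches above compute.
encoder = TransformerEncoder(
    input_size=80,
    output_size=256,
    attention_heads=4,
    pos_enc_layer_type='rope',                 # rope positional encoding
    selfattention_layer_type='rope_selfattn',  # rope self-attention class
)
decoder = TransformerDecoder(
    vocab_size=5000,
    encoder_output_size=256,
    attention_heads=4,
    input_layer='rope',
    selfattention_layer_type='rope_selfattn',
)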
7 changes: 6 additions & 1 deletion wenet/transformer/encoder_layer.py
@@ -95,7 +95,12 @@ def forward(
residual = x
if self.normalize_before:
x = self.norm1(x)
- x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
+ x_att, new_att_cache = self.self_attn(x,
+ x,
+ x,
+ mask,
+ pos_emb,
+ cache=att_cache)
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm1(x)
8 changes: 6 additions & 2 deletions wenet/transformer/search.py
@@ -303,8 +303,12 @@ def attention_beam_search(
if model.decoder.use_sdpa:
hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
# logp: (B*N, vocab)
- logp, cache = model.decoder.forward_one_step(encoder_out, encoder_mask,
- hyps, hyps_mask, cache)
+ logp, cache = model.decoder.forward_one_step(encoder_out,
+ encoder_mask,
+ hyps,
+ hyps_mask,
+ cache,
+ offset=i)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
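For reference (not part of this commit): forward_one_step now forwards an offset into self.embed, so the positional table can be indexed at the current absolute decoding position instead of always at 0. A standalone sketch of that indexing; it assumes position_encoding slices the precomputed table as pe[:, offset:offset + size], a detail of PositionalEncoding that this diff does not show.

import torch

# Build a small rope table the way precompute_freqs_cis does (see embedding.py),
# then take the factors for a chunk that starts at an absolute offset.
dim, max_len = 64, 200
inv_freq = 1.0 / (10000.0**(torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(max_len).float(), inv_freq)
pe = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)  # [1, max_len, dim//2]

offset, size = 50, 3                   # 3 new tokens starting at position 50
pos_emb = pe[:, offset:offset + size]  # [1, 3, dim//2]
pos_emb = pos_emb.unsqueeze(1)         # [1, 1, 3, dim//2], ready for apply_rotary_emb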
