diff --git a/wenet/bin/train.py b/wenet/bin/train.py
index 0915b9dc1..eba63dfb3 100644
--- a/wenet/bin/train.py
+++ b/wenet/bin/train.py
@@ -276,14 +276,16 @@ def main():
     num_params = sum(p.numel() for p in model.parameters())
     print('the number of model params: {:,d}'.format(num_params)) if local_rank == 0 else None  # noqa
-    # Freeze other parts of the model during training context bias module
     if 'context_module_conf' in configs:
+        # Freeze other parts of the model during training context bias module
         for p in model.parameters():
             p.requires_grad = False
         for p in model.context_module.parameters():
             p.requires_grad = True
         for p in model.context_module.context_decoder_ctc_linear.parameters():
             p.requires_grad = False
+        # Turn off dynamic chunk because it will affect the training of bias
+        model.encoder.use_dynamic_chunk = False
 
     # !!!IMPORTANT!!!
     # Try to export the model by script, if fails, we should refine
diff --git a/wenet/transformer/context_module.py b/wenet/transformer/context_module.py
index 2a6019bf8..d12d9d511 100644
--- a/wenet/transformer/context_module.py
+++ b/wenet/transformer/context_module.py
@@ -51,8 +51,8 @@ def forward(self, sen_batch, sen_lengths):
         _, last_state = self.sen_rnn(pack_seq)
         laste_h = last_state[0]
         laste_c = last_state[1]
-        state = torch.cat([laste_h[-1, :, :], laste_h[0, :, :],
-                           laste_c[-1, :, :], laste_c[0, :, :]], dim=-1)
+        state = torch.cat([laste_h[-1, :, :], laste_h[-2, :, :],
+                           laste_c[-1, :, :], laste_c[-2, :, :]], dim=-1)
         return state
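
Note on the context_module.py change, as a minimal standalone sketch (not part of the patch): for a bidirectional LSTM in PyTorch, h_n and c_n have shape (num_layers * 2, batch, hidden), ordered layer by layer with the forward direction before the backward one, so h_n[-2]/h_n[-1] are the final layer's forward/backward states, whereas h_n[0] is the layer-0 forward state and only coincides with the final layer when num_layers == 1. The snippet below uses illustrative sizes (not WeNet's configuration) to demonstrate the indexing the fixed torch.cat relies on.

    import torch
    import torch.nn as nn

    num_layers, batch, hidden = 2, 3, 8
    rnn = nn.LSTM(input_size=4, hidden_size=hidden, num_layers=num_layers,
                  bidirectional=True, batch_first=True)
    x = torch.randn(batch, 5, 4)
    _, (h_n, c_n) = rnn(x)

    # Shape is (num_layers * num_directions, batch, hidden), ordered as
    # [layer0_fwd, layer0_bwd, layer1_fwd, layer1_bwd].
    assert h_n.shape == (num_layers * 2, batch, hidden)

    # Final-layer forward/backward hidden and cell states, concatenated as in
    # the fixed context_module.py indexing (h[-1], h[-2], c[-1], c[-2]).
    state = torch.cat([h_n[-1], h_n[-2], c_n[-1], c_n[-2]], dim=-1)
    assert state.shape == (batch, 4 * hidden)

With the previous h[0]/c[0] indexing, the concatenation picks up layer 0's forward state, which matches the final layer only for a single-layer BLSTM; the [-1]/[-2] indexing is correct for any layer count.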