From 5d7c55caf323be61bd3599ecd81e7ea08d0a2ab2 Mon Sep 17 00:00:00 2001 From: WenH <80993860+Eric8932@users.noreply.github.com> Date: Thu, 24 Aug 2023 10:31:56 +0800 Subject: [PATCH] Update model.py (#84) If tie_weights is enabled and an lm target is present, tie the lm target's output layer to tgt_embedding first, falling back to the source embedding otherwise. --- tencentpretrain/models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tencentpretrain/models/model.py b/tencentpretrain/models/model.py index 5931f12..f7b30d6 100755 --- a/tencentpretrain/models/model.py +++ b/tencentpretrain/models/model.py @@ -21,10 +21,10 @@ def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target): if "mlm" in args.target and args.tie_weights: self.target.mlm.linear_2.weight = self.embedding.word.embedding.weight + elif "lm" in args.target and args.tie_weights and self.tgt_embedding is not None and "word" in self.tgt_embedding.embedding_name_list: + self.target.lm.output_layer.weight = self.tgt_embedding.word.embedding.weight elif "lm" in args.target and args.tie_weights and "word" in self.embedding.embedding_name_list: self.target.lm.output_layer.weight = self.embedding.word.embedding.weight - elif "lm" in args.target and args.tie_weights and "word" in self.tgt_embedding.embedding_name_list: - self.target.lm.output_layer.weight = self.tgt_embedding.word.embedding.weight if self.decoder is not None and args.share_embedding: self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight