Skip to content

Commit

Permalink
Update model.py (#84)
Browse files Browse the repository at this point in the history
If tie_weights and having lm_target, the lm_target should be consistent with tgt_embedding first.
  • Loading branch information
Eric8932 authored Aug 24, 2023
1 parent 37144d1 commit 5d7c55c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tencentpretrain/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target):

if "mlm" in args.target and args.tie_weights:
self.target.mlm.linear_2.weight = self.embedding.word.embedding.weight
elif "lm" in args.target and args.tie_weights and self.tgt_embedding is not None and "word" in self.tgt_embedding.embedding_name_list:
self.target.lm.output_layer.weight = self.tgt_embedding.word.embedding.weight
elif "lm" in args.target and args.tie_weights and "word" in self.embedding.embedding_name_list:
self.target.lm.output_layer.weight = self.embedding.word.embedding.weight
elif "lm" in args.target and args.tie_weights and "word" in self.tgt_embedding.embedding_name_list:
self.target.lm.output_layer.weight = self.tgt_embedding.word.embedding.weight

if self.decoder is not None and args.share_embedding:
self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight
Expand Down

0 comments on commit 5d7c55c

Please sign in to comment.