Description
I tried to run the code with mBART, using MBartForConditionalGeneration. To do so, in BART.py I removed the lines:
class MiniBART(MBartModel):
    def __init__(self, config):
        super().__init__(config)
        self.dst_decoder = type(self.decoder)(config, self.shared)
        self.dst_decoder.load_state_dict(self.decoder.state_dict())

    def tie_decoder(self):
        self.shared.padding_idx = self.config.pad_token_id
        self.dst_decoder = type(self.decoder)(self.config, self.shared)
        self.dst_decoder.load_state_dict(self.decoder.state_dict())
and used these lines instead:
class SimBART(MBartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)

    def tie_decoder(self):
        pass
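As far as I can tell from the transformers source, MBartForConditionalGeneration does not expose the embedding as self.shared the way MBartModel does; it wraps the base MBartModel under self.model, so the shared embedding is reachable as self.model.shared. A quick, untested check along these lines is what I mean:

from transformers import MBartForConditionalGeneration

m = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50')
# The conditional-generation class holds the base MBartModel under .model;
# the shared token embedding lives on that base model, not on the wrapper.
print(hasattr(m, 'shared'))        # expected: False
print(hasattr(m.model, 'shared'))  # expected: True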
but I got an error when running it with
python train.py --mode train --context_window 2 --pretrained_checkpoint facebook/mbart-large-50 --gradient_accumulation_steps 8 --lr 3e-5 --back_bone bart --cfg seed=557 batch_size=8
The error is:
Traceback (most recent call last):
  File "C:\Users\E\train.py", line 363, in <module>
    main()
  File "C:\Users\E\train.py", line 351, in main
    m = Model(args)
  File "C:\Users\train.py", line 35, in __init__
    self.model = SimBART.from_pretrained(args.model_path if test else 'facebook/mbart-large-50')
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\transformers\modeling_utils.py", line 1224, in from_pretrained
    model.tie_weights()
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\transformers\modeling_utils.py", line 522, in tie_weights
    output_embeddings = self.get_output_embeddings()
  File "C:\Users\E\BART.py", line 193, in get_output_embeddings
    return _make_linear_from_emb(self.shared)  # make it on the fly
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 947, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'SimBART' object has no attribute 'shared'
What can I do to fix this error?
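Is the right fix to override get_output_embeddings in SimBART so that it reaches the embedding through self.model? A rough, untested sketch of what I have in mind (it reuses the _make_linear_from_emb helper that is already defined in my BART.py):

from transformers import MBartForConditionalGeneration

class SimBART(MBartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)

    def tie_decoder(self):
        pass

    def get_output_embeddings(self):
        # The base MBartModel (self.model) owns the shared embedding,
        # so build the output projection from it instead of self.shared.
        return _make_linear_from_emb(self.model.shared)

Or is it better not to override get_output_embeddings at all and let the parent class's implementation handle it?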