@@ -88,7 +88,7 @@ def pretrain_validation(args, index, model):
     logger = args.logger
 
     model.eval()
-    dataset = PreTrainingDataset(args.tokenizer, config['validation']['path'], args.logger,
+    dataset = PreTrainingDataset(args.tokenizer, os.path.join(args.data_path_prefix, config['validation']['path']), args.logger,
                                  args.max_seq_length, index, PretrainDataType.VALIDATION, args.max_predictions_per_seq)
     data_batches = get_dataloader(args, dataset, eval_set=True)
     eval_loss = 0
@@ -130,7 +130,7 @@ def get_train_dataset(args, index, finetune=False, shuffle=True):
     # Load Wiki Dataset
     wiki_pretrain_dataset = PreTrainingDataset(
         args.tokenizer,
-        dataset_paths['wiki_pretrain_dataset'],
+        os.path.join(args.data_path_prefix, dataset_paths['wiki_pretrain_dataset']),
         args.logger,
         args.max_seq_length,
         index,
@@ -145,7 +145,7 @@ def get_train_dataset(args, index, finetune=False, shuffle=True):
 
     bc_pretrain_dataset = PreTrainingDataset(
         args.tokenizer,
-        dataset_paths['bc_pretrain_dataset'],
+        os.path.join(args.data_path_prefix, dataset_paths['bc_pretrain_dataset']),
         args.logger,
         args.max_seq_length,
         index,
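
The change is identical at all three call sites: the relative dataset paths read from config / dataset_paths are joined with args.data_path_prefix before being passed to PreTrainingDataset. A minimal sketch of the resulting path resolution, using hypothetical example values in place of the real config entries:

    import os

    data_path_prefix = "/data/bing_bert"   # hypothetical value of args.data_path_prefix
    relative_path = "validation_set"       # hypothetical value of config['validation']['path']

    # Same join the patch applies in pretrain_validation() and get_train_dataset()
    print(os.path.join(data_path_prefix, relative_path))  # /data/bing_bert/validation_set

Note that os.path.join returns the second argument unchanged when it is already absolute, so any pre-existing absolute paths in the config would keep working after this change.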