Skip to content

Commit 6a698b2

Browse files
authored
aml datastore support using a data_path_prefix (deepspeedai#16)
1 parent 91a794c commit 6a698b2

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

bing_bert/bert_large_lamb_seq128.json

+3 -3
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,8 @@
2323
"pretrain_type": "wiki_bc"
2424
},
2525
"datasets": {
26-
"wiki_pretrain_dataset": "/data/bert/bnorick_format/128/wiki_pretrain",
27-
"bc_pretrain_dataset": "/data/bert/bnorick_format/128/bookcorpus_pretrain"
26+
"wiki_pretrain_dataset": "bnorick_format/128/wiki_pretrain",
27+
"bc_pretrain_dataset": "bnorick_format/128/bookcorpus_pretrain"
2828
},
2929
"tp1pp_evalsets":
3030
{
@@ -45,6 +45,6 @@
4545
"total_training_steps": 187000
4646
},
4747
"validation": {
48-
"path": "/data/bert/validation_set/"
48+
"path": "validation_set/"
4949
}
5050
}

bing_bert/deepspeed_train.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ def pretrain_validation(args, index, model):
8888
logger = args.logger
8989

9090
model.eval()
91-
dataset = PreTrainingDataset(args.tokenizer, config['validation']['path'], args.logger,
91+
dataset = PreTrainingDataset(args.tokenizer, os.path.join(args.data_path_prefix, config['validation']['path']), args.logger,
9292
args.max_seq_length, index, PretrainDataType.VALIDATION, args.max_predictions_per_seq)
9393
data_batches = get_dataloader(args, dataset, eval_set=True)
9494
eval_loss = 0
@@ -130,7 +130,7 @@ def get_train_dataset(args, index, finetune=False, shuffle=True):
130130
# Load Wiki Dataset
131131
wiki_pretrain_dataset = PreTrainingDataset(
132132
args.tokenizer,
133-
dataset_paths['wiki_pretrain_dataset'],
133+
os.path.join(args.data_path_prefix, dataset_paths['wiki_pretrain_dataset']),
134134
args.logger,
135135
args.max_seq_length,
136136
index,
@@ -145,7 +145,7 @@ def get_train_dataset(args, index, finetune=False, shuffle=True):
145145

146146
bc_pretrain_dataset = PreTrainingDataset(
147147
args.tokenizer,
148-
dataset_paths['bc_pretrain_dataset'],
148+
os.path.join(args.data_path_prefix, dataset_paths['bc_pretrain_dataset']),
149149
args.logger,
150150
args.max_seq_length,
151151
index,

bing_bert/utils.py

+5
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,11 @@ def get_argument_parser():
8080
default=100,
8181
help='Interval to print training details.')
8282

83+
parser.add_argument('--data_path_prefix',
84+
type=str,
85+
default="",
86+
help="Path to prefix data loading, helpful for AML and other environments")
87+
8388
return parser
8489

8590
def is_time_to_exit(args, epoch_steps=0, global_steps=0):

0 commit comments

Comments
 (0)