@@ -31,10 +31,10 @@
 config = Config(
     dict(
         parallel=dict(
-            zero1=dict(size=-1, fsdp=False),
-            pipeline=dict(size=1, interleaved_overlap=False),
-            sequence_parallel=False,
-            tensor=1,
+            zero1=dict(size=-1),
+            tensor=dict(size=1, mode="mtp"),
+            pipeline=dict(size=1, interleaved_overlap=True),
+            weight=dict(size=1, overlap=True, memory_pool=True),
         ),
         data=dict(
             seq_len=2048,
@@ -47,9 +47,7 @@
             valid_every=300,
             rampup_batch_size=None,
             diag_outlier_ratio=1.1,
-            train_folder=os.path.join(
-                os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/train"
-            ),
+            train_folder=None,
             valid_folder=os.path.join(
                 os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/val"
             ),
@@ -118,6 +116,7 @@
         loss=dict(
             label_smoothing=0,
         ),
+        cudnn_deterministic=True,
     )
 )
 
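The new `cudnn_deterministic=True` flag works together with the `seed_all` helper touched in the next hunk. As a rough sketch of what such a flag typically toggles in PyTorch (only the `seed_all` signature and the `benchmark = True` line come from this diff; the rest is assumed standard seeding code, not this repo's implementation):

```python
import random

import numpy as np
import torch


def seed_all(seed, cuda_deterministic=False):
    # seed every RNG that can influence training
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if cuda_deterministic:
        # bit-wise reproducible cuDNN kernels, at some throughput cost
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        # faster autotuned kernels; repeated runs are no longer bit-identical
        torch.backends.cudnn.benchmark = True
```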
@@ -149,16 +148,6 @@ def seed_all(seed, cuda_deterministic=False):
         torch.backends.cudnn.benchmark = True
 
 
-def load_new_batch(train_dl, train_iter):
-    try:
-        batch = next(train_iter)
-    except StopIteration:
-        train_iter = iter(train_dl)
-        batch = next(train_iter)
-
-    return batch, train_iter
-
-
 def evaluate_on_val_dls(
     trainer,
     val_dls,
@@ -241,7 +230,7 @@ def check_grad_norm(grad_norm_list):
 
     logger.info(f"norm_mean: {tensor_trimmed_mean1}, {tensor_trimmed_mean2}")
     assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-1, atol=3e-1)
-    logger.info(f"grad norm check passed")
+    logger.info("grad norm check passed")
 
 
 def check_meanLoss_val(all_loss, all_val):
@@ -258,10 +247,10 @@ def check_meanLoss_val(all_loss, all_val):
     logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
     logger.info(f"all_val: {all_val}")
 
-    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
-    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
+    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=5e-2, atol=5e-2)
+    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=5e-2, atol=5e-2)
 
-    logger.info(f"loss check passed")
+    logger.info("loss check passed")
 
 
 def exam_loss(args):
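Both checks compare trimmed means of per-step values with `torch.allclose`, which is why loosening the tolerances from 3e-2 to 5e-2 still catches real regressions while tolerating run-to-run noise. A minimal sketch of that pattern, assuming a hypothetical `trimmed_mean` helper and made-up sample values (the real `tensor_trimmed_mean*` inputs are computed elsewhere in the test):

```python
import torch


def trimmed_mean(values, trim_ratio=0.2):
    # drop the lowest and highest trim_ratio share of entries, average the rest
    t, _ = torch.sort(torch.tensor(values, dtype=torch.float32))
    k = int(len(t) * trim_ratio)
    return t[k : len(t) - k].mean() if len(t) > 2 * k else t.mean()


baseline = [2.29, 2.30, 2.31, 2.35, 9.99]  # one outlier step
current = [2.30, 2.31, 2.32, 2.33, 2.34]
m1, m2 = trimmed_mean(baseline), trimmed_mean(current)
assert torch.allclose(m1, m2, rtol=5e-2, atol=5e-2)  # same tolerance as the relaxed check
```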
@@ -321,16 +310,18 @@ def exam_loss(args):
     )
 
     trainer.train()
-    train_iter = iter(train_dl)
 
-    # transfer the train data loader into train data iterator
+    # collect per-step loss, validation and grad-norm results for the checks below
     loss_list = []
     val_list = []
     grad_norm_list = []
+    share_data_path = os.environ["share_data_path"]
+    data_path = os.path.join(share_data_path, "quality_assurance/0623_data_batch")
     for batch_count in range(total_steps):
         start_time = time.time()
         # load batch data
-        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)
+        batch_path = os.path.join(data_path, f"batch_{batch_count}_{rank}.pt")
+        batch = torch.load(batch_path)
 
         # zero the grads of parameters
         trainer.zero_grad()
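With `train_folder=None` and `load_new_batch` removed, each training step now replays a fixed, pre-dumped batch per rank from `quality_assurance/0623_data_batch`, so losses and grad norms are comparable across runs. A sketch of how such files could be produced once from a reference run; the `dump_batches` helper is hypothetical, and only the `batch_{batch_count}_{rank}.pt` naming and the directory come from this diff:

```python
import os

import torch


def dump_batches(train_dl, total_steps, rank, data_path):
    # save the first total_steps batches seen by this rank so later CI runs
    # can replay exactly the same data with torch.load
    os.makedirs(data_path, exist_ok=True)
    train_iter = iter(train_dl)
    for batch_count in range(total_steps):
        batch = next(train_iter)
        torch.save(batch, os.path.join(data_path, f"batch_{batch_count}_{rank}.pt"))
```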