
Commit 671926b

fix(QA): fix test_swap_nb_loss_and_gradnorm (#63)
1 parent 086fd3e commit 671926b

File tree: 1 file changed (+14, -23 lines)


tests/test_training/test_swap_nb_loss_and_gradnorm.py

Lines changed: 14 additions & 23 deletions
@@ -31,10 +31,10 @@
 config = Config(
     dict(
         parallel=dict(
-            zero1=dict(size=-1, fsdp=False),
-            pipeline=dict(size=1, interleaved_overlap=False),
-            sequence_parallel=False,
-            tensor=1,
+            zero1=dict(size=-1),
+            tensor=dict(size=1, mode="mtp"),
+            pipeline=dict(size=1, interleaved_overlap=True),
+            weight=dict(size=1, overlap=True, memory_pool=True),
         ),
         data=dict(
             seq_len=2048,
@@ -47,9 +47,7 @@
             valid_every=300,
             rampup_batch_size=None,
             diag_outlier_ratio=1.1,
-            train_folder=os.path.join(
-                os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/train"
-            ),
+            train_folder=None,
             valid_folder=os.path.join(
                 os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/val"
             ),
@@ -118,6 +116,7 @@
         loss=dict(
             label_smoothing=0,
         ),
+        cudnn_deterministic=True,
     )
 )
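
The new cudnn_deterministic=True entry is what makes cross-run loss and grad-norm comparisons meaningful. As a rough sketch (assumption: the framework forwards this config flag to the cuDNN backend, as PyTorch training code usually does), it corresponds to:

import torch

# Assumed wiring of the cudnn_deterministic config flag (not from this commit):
torch.backends.cudnn.deterministic = True  # select deterministic kernels only
torch.backends.cudnn.benchmark = False     # disable non-deterministic autotuning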

@@ -149,16 +148,6 @@ def seed_all(seed, cuda_deterministic=False):
         torch.backends.cudnn.benchmark = True


-def load_new_batch(train_dl, train_iter):
-    try:
-        batch = next(train_iter)
-    except StopIteration:
-        train_iter = iter(train_dl)
-        batch = next(train_iter)
-
-    return batch, train_iter
-
-
 def evaluate_on_val_dls(
     trainer,
     val_dls,
@@ -241,7 +230,7 @@ def check_grad_norm(grad_norm_list):

     logger.info(f"norm_mean: {tensor_trimmed_mean1}, {tensor_trimmed_mean2}")
     assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-1, atol=3e-1)
-    logger.info(f"grad norm check passed")
+    logger.info("grad norm check passed")


 def check_meanLoss_val(all_loss, all_val):
@@ -258,10 +247,10 @@ def check_meanLoss_val(all_loss, all_val):
     logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
     logger.info(f"all_val: {all_val}")

-    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
-    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
+    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=5e-2, atol=5e-2)
+    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=5e-2, atol=5e-2)

-    logger.info(f"loss check passed")
+    logger.info("loss check passed")


 def exam_loss(args):
@@ -321,16 +310,18 @@ def exam_loss(args):
     )

     trainer.train()
-    train_iter = iter(train_dl)

     # transfer the train data loader into train data iterator
     loss_list = []
     val_list = []
     grad_norm_list = []
+    share_data_path = os.environ["share_data_path"]
+    data_path = os.path.join(share_data_path, "quality_assurance/0623_data_batch")
     for batch_count in range(total_steps):
         start_time = time.time()
         # load batch data
-        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)
+        batch_path = os.path.join(data_path, f"batch_{batch_count}_{rank}.pt")
+        batch = torch.load(batch_path)

         # zero the grads of parameters
         trainer.zero_grad()
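
The test now replays pre-saved batches instead of drawing from the live dataloader (hence the removal of load_new_batch above), so every run sees identical data. The batch_{batch_count}_{rank}.pt fixtures would have been produced ahead of time by a dump step along these lines; dump_batches below is a hypothetical sketch, not part of this commit:

import os
import torch

def dump_batches(train_dl, total_steps, rank, data_path):
    # Hypothetical helper (not in this commit): save one batch per step and
    # rank so a later run can replay training deterministically via torch.load.
    os.makedirs(data_path, exist_ok=True)
    train_iter = iter(train_dl)
    for step in range(total_steps):
        batch = next(train_iter)
        torch.save(batch, os.path.join(data_path, f"batch_{step}_{rank}.pt"))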
