fix(QA): fix test_swap_nb_loss_and_gradnorm (#63)
li126com authored Feb 29, 2024
1 parent 086fd3e commit 671926b
Showing 1 changed file with 14 additions and 23 deletions.
37 changes: 14 additions & 23 deletions tests/test_training/test_swap_nb_loss_and_gradnorm.py
@@ -31,10 +31,10 @@
 config = Config(
     dict(
         parallel=dict(
-            zero1=dict(size=-1, fsdp=False),
-            pipeline=dict(size=1, interleaved_overlap=False),
-            sequence_parallel=False,
-            tensor=1,
+            zero1=dict(size=-1),
+            tensor=dict(size=1, mode="mtp"),
+            pipeline=dict(size=1, interleaved_overlap=True),
+            weight=dict(size=1, overlap=True, memory_pool=True),
         ),
         data=dict(
             seq_len=2048,
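As context for the hunk above: the commit moves the test's parallel settings from the legacy flat schema (tensor=1, sequence_parallel=False, zero1 with an fsdp flag) to the newer dict-based one. An annotated sketch of the resulting block follows, with my reading of each field in comments; these readings are assumptions about the schema, not something this diff states.

    # Sketch of the new-style parallel config; field readings are assumptions.
    parallel = dict(
        zero1=dict(size=-1),  # ZeRO-1 sharding across all data-parallel ranks
        tensor=dict(size=1, mode="mtp"),  # TP degree 1; "mtp" ~ Megatron-style tensor parallel
        pipeline=dict(size=1, interleaved_overlap=True),  # single stage; overlap for interleaved schedule
        weight=dict(size=1, overlap=True, memory_pool=True),  # weight parallel with comm overlap and a comm memory pool
    )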
@@ -47,9 +47,7 @@
         valid_every=300,
         rampup_batch_size=None,
         diag_outlier_ratio=1.1,
-        train_folder=os.path.join(
-            os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/train"
-        ),
+        train_folder=None,
         valid_folder=os.path.join(
             os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/val"
         ),
@@ -118,6 +116,7 @@
         loss=dict(
             label_smoothing=0,
         ),
+        cudnn_deterministic=True,
     )
 )
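The newly added cudnn_deterministic=True is what makes the two runs compared by this test numerically reproducible: with cuDNN autotuning enabled, convolution algorithm selection can vary between runs and perturb losses and grad norms. A minimal sketch of how a seeding helper typically honors such a flag (illustrative only; the repo's own seed_all, visible in the next hunk, is the real implementation):

    import random

    import numpy as np
    import torch

    def seed_all(seed, cuda_deterministic=False):
        # Seed every RNG the training loop touches.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        if cuda_deterministic:
            # Reproducible kernels: fixed algorithms, no autotuning.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            # Faster, but run-to-run variable.
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True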

@@ -149,16 +148,6 @@ def seed_all(seed, cuda_deterministic=False):
         torch.backends.cudnn.benchmark = True


-def load_new_batch(train_dl, train_iter):
-    try:
-        batch = next(train_iter)
-    except StopIteration:
-        train_iter = iter(train_dl)
-        batch = next(train_iter)
-
-    return batch, train_iter
-
-
 def evaluate_on_val_dls(
     trainer,
     val_dls,
@@ -241,7 +230,7 @@ def check_grad_norm(grad_norm_list):

     logger.info(f"norm_mean: {tensor_trimmed_mean1}, {tensor_trimmed_mean2}")
     assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-1, atol=3e-1)
-    logger.info(f"grad norm check passed")
+    logger.info("grad norm check passed")


 def check_meanLoss_val(all_loss, all_val):
@@ -258,10 +247,10 @@ def check_meanLoss_val(all_loss, all_val):
     logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
     logger.info(f"all_val: {all_val}")

-    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
-    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
+    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=5e-2, atol=5e-2)
+    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=5e-2, atol=5e-2)

-    logger.info(f"loss check passed")
+    logger.info("loss check passed")


 def exam_loss(args):
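Both the grad-norm check and the loss check rely on torch.allclose, which passes only if every element satisfies |a - b| <= atol + rtol * |b|; this commit loosens the loss tolerances from 3e-2 to 5e-2. A quick illustration with invented values:

    import torch

    a = torch.tensor([2.300, 0.010])
    b = torch.tensor([2.390, 0.045])
    # passes: 0.090 <= 0.05 + 0.05 * 2.390 and 0.035 <= 0.05 + 0.05 * 0.045
    print(torch.allclose(a, b, rtol=5e-2, atol=5e-2))  # True
    # fails on the second element: 0.035 > 0.03 + 0.03 * 0.045
    print(torch.allclose(a, b, rtol=3e-2, atol=3e-2))  # False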
@@ -321,16 +310,18 @@ def exam_loss(args):
     )

     trainer.train()
-    train_iter = iter(train_dl)

     # transfer the train data loader into train data iterator
     loss_list = []
     val_list = []
     grad_norm_list = []
+    share_data_path = os.environ["share_data_path"]
+    data_path = os.path.join(share_data_path, "quality_assurance/0623_data_batch")
     for batch_count in range(total_steps):
         start_time = time.time()
         # load batch data
-        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)
+        batch_path = os.path.join(data_path, f"batch_{batch_count}_{rank}.pt")
+        batch = torch.load(batch_path)

         # zero the grads of parameters
         trainer.zero_grad()
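The removed load_new_batch helper drew fresh batches from the live dataloader; the test now replays fixed batches saved per step and per rank under quality_assurance/0623_data_batch, which takes dataloader nondeterminism out of the loss/grad-norm comparison. The dump side is not part of this commit; a plausible sketch of how such fixtures could have been produced (file naming follows the loading code above, everything else is assumption):

    import os

    import torch

    def dump_fixed_batches(train_dl, total_steps, rank, out_dir):
        # Hypothetical fixture generator: save one batch per step and rank
        # so a later test run can replay them with torch.load.
        os.makedirs(out_dir, exist_ok=True)
        train_iter = iter(train_dl)
        for batch_count in range(total_steps):
            try:
                batch = next(train_iter)
            except StopIteration:  # wrap around, as the removed helper did
                train_iter = iter(train_dl)
                batch = next(train_iter)
            torch.save(batch, os.path.join(out_dir, f"batch_{batch_count}_{rank}.pt"))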