Revert the result validation
rpsilva-aws committed Jan 22, 2025
1 parent 83f0af1, commit 3cd4542
Showing 2 changed files with 11 additions and 18 deletions.
test/spmd/test_train_spmd_linear_model.py (26 changes: 9 additions & 17 deletions)
@@ -32,21 +32,17 @@ class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest):
   def test_basic(self):
     print('Training loop with baseline')
     with extended_argv([]):
-      baseline_losses, baseline_result = train_and_evaluate()
+      baseline_losses = train_and_evaluate()
       # Verify that the model losses are not zero.
       assert all(loss != 0 for loss in baseline_losses)
-      # Verify that the model produces non-zero outputs.
-      assert not torch.any(baseline_result == 0)

     if not SKIP_GRADIENT_CHECKPOINTING:
       print('Training loop with gradient checkpointing')
       with extended_argv(['--use_gradient_checkpointing']):
-        checkpointing_losses, checkpointing_result = train_and_evaluate()
+        checkpointing_losses = train_and_evaluate()
         # Verify that the runs match with and without checkpointing.
-        assert torch.allclose(
-            baseline_result, checkpointing_result, rtol=1e-06, atol=1e-09)
         assert all(
-            torch.allclose(baseline_loss, checkpointing_loss)
+            torch.allclose(
+                baseline_loss, checkpointing_loss, rtol=1e-3, atol=1e-7)
             for baseline_loss, checkpointing_loss in zip(
                 baseline_losses, checkpointing_losses))

@@ -55,30 +51,26 @@ class TestSPMDLinearModelGradientAccumulation(
     test_xla_sharding_base.XlaShardingTest):

   def test_gradient_accumulation_matches(self):
-    """Verify that gradient accumulation produces the same results and losses
-    with and without the XLA `While` op.
+    """Verify that gradient accumulation produces the same losses with and
+    without the XLA `While` op.
     """

     COMMON_GRAD_ACC_ARGS = ["--gradient_accumulation_steps", "8"]
     print('Training loop with traditional gradient accumulation')
     with extended_argv(COMMON_GRAD_ACC_ARGS):
-      baseline_grad_acc_losses, baseline_grad_acc_result = train_and_evaluate()
+      baseline_grad_acc_losses = train_and_evaluate()

     print('Training loop with XLA\'s `While` gradient accumulation')
     with extended_argv(COMMON_GRAD_ACC_ARGS +
                        ["--use_gradient_accumulation_loop"]):
-      loop_grad_acc_losses, loop_grad_acc_result = train_and_evaluate()
+      loop_grad_acc_losses = train_and_evaluate()

     # Verify that the model losses are not zero, and that the runs match.
     assert all(loss != 0 for loss in baseline_grad_acc_losses)
     assert all(
-        torch.allclose(baseline_loss, checkpointing_loss)
+        torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-3, atol=1e-7)
         for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
                                                      loop_grad_acc_losses))
-    # Verify that the model produces non-zero outputs, and that the runs match.
-    assert not torch.any(baseline_grad_acc_result == 0)
-    assert torch.allclose(
-        baseline_grad_acc_result, loop_grad_acc_result, rtol=1e-06, atol=1e-09)


 if __name__ == '__main__':
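A note on the relaxed per-loss check in the hunk above: torch.allclose(input, other, rtol, atol) treats two tensors as equal when |input - other| <= atol + rtol * |other| holds elementwise, so rtol=1e-3 with atol=1e-7 tolerates roughly 0.1% relative drift between the baseline and checkpointing losses. A minimal standalone sketch, with illustrative loss values that are not taken from the test:

import torch

# Hypothetical per-step losses; they differ by about 0.03% relative error.
baseline_loss = torch.tensor(0.52310)
checkpointing_loss = torch.tensor(0.52325)

# Passes under the tolerances used in the updated assertions.
print(torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-3, atol=1e-7))  # True

# Fails under the stricter tolerances that the removed result comparison used.
print(torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-06, atol=1e-09))  # False
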
test/utils/train_spmd_linear_model.py (3 changes: 2 additions & 1 deletion)
@@ -184,4 +184,5 @@ def train_and_evaluate():
   print('Start training loop...')
   losses, m = train()
   t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device())
-  return [loss.cpu() for loss in losses], m(t).cpu()
+  m(t).cpu()
+  return [loss.cpu() for loss in losses]
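
With this change, train_and_evaluate still runs the trained model on a random input, but the output is discarded and only the list of per-step losses, moved to the CPU, is returned. A hedged sketch of the new calling convention; the import path is assumed and not shown in this diff:

from train_spmd_linear_model import train_and_evaluate  # import path assumed

losses = train_and_evaluate()  # single return value: a list of CPU loss tensors
assert all(loss != 0 for loss in losses)
print('final loss:', losses[-1].item())  # assumes each loss is a scalar tensor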
