Skip to content

Commit

Permalink
Improve and refine MLP tests for extensibility and A/B testing (#8561)
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws committed Jan 17, 2025
1 parent 7f75b30 commit 0dd7d63
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 136 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing
run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
Expand Down
Empty file added test/spmd/__init__.py
Empty file.
191 changes: 55 additions & 136 deletions test/spmd/test_train_spmd_linear_model.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,57 @@
import args_parse
import numpy as np
import torch
from torch import nn
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.utils.checkpoint as checkpoint
import torch_xla.utils.utils as xu
from torch_xla.distributed.spmd import Mesh
import torch.optim as optim
from torch import nn

MODEL_OPTS = {
'--sharding': {
'choices': ['batch', 'megatron-lm', 'fsdp'],
'nargs': '+',
'default': [],
},
'--input_dim': {
'type': int,
'default': 16834,
},
'--train_dataset_len': {
'type': int,
'default': 1024 * 1024,
},
'--use_gradient_checkpointing': {
'action': 'store_true',
}
}

FLAGS = args_parse.parse_common_options(
batch_size=128, num_epochs=1, opts=MODEL_OPTS.items())

xr.use_spmd(auto=FLAGS.auto_spmd)


class SimpleLinear(nn.Module):

def __init__(self):
super(SimpleLinear, self).__init__()
self.fc1 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(FLAGS.input_dim // 2, 1)
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)

def forward(self, x):
y = self.relu(self.fc1(x))
z = self.fc2(y)
return self.fc3(z)


device = xm.xla_device()


def train():
print('===> Preparing data..')
lr = 0.1
train_loader = xu.SampleGenerator(
data=(torch.zeros(FLAGS.batch_size, FLAGS.input_dim),
torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
torch.manual_seed(42)
model = SimpleLinear().to(device)
import argparse
from contextlib import contextmanager
import os
import sys
import unittest

num_devices = xr.global_runtime_device_count()
print(f'num_devices: {num_devices}')
# Define a mesh with all devices along one axis
mesh_shape = (num_devices, 1)
device_ids = np.arange(num_devices)
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

if 'batch' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))

if 'fsdp' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
print('Sharding model weights')
# Shard the weights according to their 0th dim
xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
xs.mark_sharding(model.fc2.weight, mesh, (0, 1))

if 'megatron-lm' in FLAGS.sharding:
print('Sharding model weights')
# Shard the first layer's weights row-wise
xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
# Shard the second layer's weights column-wise
xs.mark_sharding(model.fc2.weight, mesh, (1, 0))

optimizer = optim.SGD(model.parameters(), lr=lr)

loss_fn = nn.CrossEntropyLoss()

def train_loop_fn(loader, epoch):
model.train()
for step, (data, target) in enumerate(loader):
with xp.StepTrace('train_linear_model'):
with xp.Trace('build_graph'):
x = data.to(device)
y = target.to(device)
optimizer.zero_grad()
if FLAGS.use_gradient_checkpointing:
for n_l, layer in enumerate(model):
# Apply gradient checkpointing for reduced memory footprint.
# This would result in increased computation cost.
if n_l > 0:
x = torch_xla.utils.checkpoint.checkpoint(layer, x)
output = x
else:
output = model(x)
loss = loss_fn(output, y)
loss.backward()
optimizer.step()
xm.mark_step()
if step % 10 == 0:
print(f"Epoch {epoch} step {step} loss {loss}")

for epoch in range(FLAGS.num_epochs):
train_loop_fn(train_loader, epoch)

return model


if FLAGS.profile:
server = xp.start_server(FLAGS.profiler_port)
import torch

print('Start training loop...')
m = train()
t = torch.randn(10, FLAGS.input_dim).to(device)
m(t).cpu()
import test_xla_sharding_base

parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)
from utils.train_spmd_linear_model import train_and_evaluate

SKIP_GRADIENT_CHECKPOINTING: bool = False


@contextmanager
def extended_argv(args):
original_argv = sys.argv[:]
sys.argv.extend(args)
try:
yield
finally:
sys.argv = original_argv


class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest):

def test_basic(self):
print('Training loop with baseline')
with extended_argv([]):
baseline_losses, baseline_result = train_and_evaluate()
# Verify that the model losses are not zero.
assert all(loss != 0 for loss in baseline_losses)
# Verify that the model produces non-zero outputs.
assert not torch.any(baseline_result == 0)

if not SKIP_GRADIENT_CHECKPOINTING:
print('Training loop with gradient checkpointing')
with extended_argv(['--use_gradient_checkpointing']):
checkpointing_losses, checkpointing_result = train_and_evaluate()
# Verify that the runs match with and without checkpointing.
assert torch.allclose(baseline_result, checkpointing_result)
assert all(
torch.allclose(baseline_loss, checkpointing_loss)
for baseline_loss, checkpointing_loss in zip(
baseline_losses, checkpointing_losses))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--skip-gradient-checkpointing', action='store_true')
parsed_args, remaining_argv = parser.parse_known_args()
SKIP_GRADIENT_CHECKPOINTING = parsed_args.skip_gradient_checkpointing
test = unittest.main(argv=[sys.argv[0]] + remaining_argv)
sys.exit(0 if test.result.wasSuccessful() else 1)
Empty file added test/utils/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions test/utils/train_spmd_linear_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import sys
from typing import Optional

import numpy as np
import torch
from torch import nn
import torch.optim as optim

import args_parse
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
from torch_xla.distributed.spmd import Mesh
from torch_xla.utils.checkpoint import checkpoint

MODEL_OPTS = {
'--sharding': {
'choices': ['batch', 'megatron-lm', 'fsdp'],
'nargs': '+',
'default': [],
},
'--input_dim': {
'type': int,
'default': 16834,
},
'--train_dataset_len': {
'type': int,
'default': 1024 * 8,
},
'--use_gradient_checkpointing': {
'action': 'store_true',
}
}

FLAGS = {}
PROFILER_SERVER = None


class SimpleLinear(nn.Module):
NUM_CLASSES = 3

def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2),
nn.ReLU(),
nn.Linear(FLAGS.input_dim // 2, 3),
# # Add an additional 3x3 layer at the end to ensure the final layer
# # is not sharded.
nn.Linear(3, self.NUM_CLASSES),
)

def forward(self, x):
if FLAGS.use_gradient_checkpointing:
for n_l, layer in enumerate(self.layers):
# Apply gradient checkpointing for reduced memory footprint.
# This would result in increased computation cost.
if n_l > 0:
x = checkpoint(layer, x)
else:
x = layer(x)
else:
x = self.layers(x)
return x


def train():
device = xm.xla_device()
torch.manual_seed(42)
model = SimpleLinear().to(device)
print('===> Preparing data..')
train_loader = xu.SampleGenerator(
data=(torch.randn(FLAGS.batch_size, FLAGS.input_dim),
torch.randint(
0, model.NUM_CLASSES, (FLAGS.batch_size,), dtype=torch.int64)),
sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)

num_devices = xr.global_runtime_device_count()
print(f'num_devices: {num_devices}')
# Define a mesh with all devices along one axis
mesh_shape = (num_devices, 1)
device_ids = np.arange(num_devices)
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

if 'batch' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))

if 'fsdp' in FLAGS.sharding:
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
print('Sharding model weights')
# Shard the weights according to their 0th dim
xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
xs.mark_sharding(model.layers[2].weight, mesh, (0, 1))

if 'megatron-lm' in FLAGS.sharding:
print('Sharding model weights')
# Shard the first layer's weights row-wise
xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
# Shard the second layer's weights column-wise
xs.mark_sharding(model.layers[2].weight, mesh, (1, 0))

optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)

loss_fn = nn.CrossEntropyLoss()

def train_loop_fn(loader, epoch):
model.train()
for step, (data, target) in enumerate(loader):
with xp.StepTrace('train_linear_model'):
with xp.Trace('build_graph'):
x = data.to(device)
y = target.to(device)
optimizer.zero_grad()
output = model(x)
loss = loss_fn(output, y)
losses.append(loss.clone().detach())
loss.backward()
optimizer.step()
xm.mark_step()
if step % FLAGS.log_steps == 0:
print(f"Epoch {epoch} step {step} loss {loss}")

losses = []
for epoch in range(FLAGS.num_epochs):
train_loop_fn(train_loader, epoch)
return losses, model


def train_and_evaluate():
default_config = {
'batch_size': 128,
'num_epochs': 1,
'lr': 0.1,
'log_steps': 8,
'opts': MODEL_OPTS.items()
}

global PROFILER_SERVER, FLAGS
FLAGS = args_parse.parse_common_options(**default_config)
if FLAGS.profile:
PROFILER_SERVER = xp.start_server(FLAGS.profiler_port)
xr.use_spmd(auto=FLAGS.auto_spmd)
print('Start training loop...')
losses, m = train()
t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device())
return [loss.cpu() for loss in losses], m(t).cpu()

0 comments on commit 0dd7d63

Please sign in to comment.