[DOC] Update mnist.py example #1270

Open
orion160 opened this issue Jun 12, 2024 · 10 comments

@orion160

Update the example at https://github.com/pytorch/examples/blob/main/mnist/main.py to use torch.compile features.
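
For context, the core of the change is wrapping the model with torch.compile after it has been moved to the target device. A minimal sketch (assuming the Net class and device selection already present in main.py; the backend and mode arguments are optional):

import torch

model = Net().to(device)
# torch.compile wraps the module; the first forward pass triggers compilation
# and subsequent passes reuse the optimized code. Training and inference then
# use the compiled model exactly as before.
model = torch.compile(model)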

@orion160 (Author) commented Jun 12, 2024

Proposal:

diff --git a/mnist/main.py b/mnist/main.py
index 184dc47..a3cffd1 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -3,13 +3,14 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
+from torchvision import datasets
+from torchvision.transforms import v2 as transforms
 from torch.optim.lr_scheduler import StepLR
 
 
 class Net(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
         self.dropout1 = nn.Dropout(0.25)
@@ -33,19 +34,42 @@ class Net(nn.Module):
         return output
 
 
-def train(args, model, device, train_loader, optimizer, epoch):
+def train_amp(args, model, device, train_loader, opt, epoch, scaler):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
+        with torch.autocast(device_type=device.type):
+            output = model(data)
+            loss = F.nll_loss(output, target)
+        scaler.scale(loss).backward()
+        scaler.step(opt)
+        scaler.update()
+        if batch_idx % args.log_interval == 0:
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
+            if args.dry_run:
+                break
+
+
+def train(args, model, device, train_loader, opt, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
         output = model(data)
         loss = F.nll_loss(output, target)
         loss.backward()
-        optimizer.step()
+        opt.step()
         if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
             if args.dry_run:
                 break
 
@@ -58,43 +82,125 @@ def test(model, device, test_loader):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
 
     test_loss /= len(test_loader.dataset)
 
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+    print(
+        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
+    )
 
 
-def main():
+def parse_args():
     # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                        help='number of epochs to train (default: 14)')
-    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
-                        help='learning rate (default: 1.0)')
-    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
-                        help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
-                        help='quickly check a single pass')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
-                        help='For Saving the current Model')
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--no-mps",
+        action="store_true",
+        default=False,
+        help="disables macOS GPU training",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--use-amp",
+        action="store_true",
+        default=False,
+        help="use automatic mixed precision",
+    )
+    parser.add_argument(
+        "--compile-backend",
+        type=str,
+        default="inductor",
+        metavar="BACKEND",
+        help="backend to compile the model with",
+    )
+    parser.add_argument(
+        "--compile-mode",
+        type=str,
+        default="default",
+        metavar="MODE",
+        help="compilation mode",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    parser.add_argument(
+        "--data-dir",
+        type=str,
+        default="../data",
+        metavar="DIR",
+        help="path to the data directory",
+    )
     args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
 
@@ -107,32 +213,43 @@ def main():
     else:
         device = torch.device("cpu")
 
-    train_kwargs = {'batch_size': args.batch_size}
-    test_kwargs = {'batch_size': args.test_batch_size}
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
-        cuda_kwargs = {'num_workers': 1,
-                       'pin_memory': True,
-                       'shuffle': True}
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
 
-    transform=transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-        ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+    transform = transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
+        ]
+    )
+
+    data_dir = args.data_dir
+
+    dataset1 = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST(data_dir, train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
 
-    model = Net().to(device)
-    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+    model = Net().to(device, memory_format=torch.channels_last)
+    model = torch.compile(model, backend=args.compile_backend, mode=args.compile_mode)
+    optimizer = optim.Adadelta(model.parameters(), lr=torch.tensor(args.lr))
 
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    scaler = None
+    if use_cuda and args.use_amp:
+        scaler = torch.GradScaler(device=device.type)
+
     for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
+        if scaler is None:
+            train(args, model, device, train_loader, optimizer, epoch)
+        else:
+            train_amp(args, model, device, train_loader, optimizer, epoch, scaler)
         test(model, device, test_loader)
         scheduler.step()
 
@@ -140,5 +257,5 @@ def main():
         torch.save(model.state_dict(), "mnist_cnn.pt")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

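A caveat with the torch.compile call above: it returns an OptimizedModule wrapper, so model.state_dict() keys gain an _orig_mod. prefix. If --save-model should keep producing a checkpoint loadable by a plain Net(), one option (sketch only) is to save the underlying module's weights instead:

# The compiled wrapper exposes the original module as _orig_mod; saving its
# state_dict keeps the checkpoint keys compatible with an uncompiled Net().
orig_model = getattr(model, "_orig_mod", model)
torch.save(orig_model.state_dict(), "mnist_cnn.pt")
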
@orion160 (Author)

CC @svekars

orion160 changed the title from "[DOC] [DOCA 2024] Update mnist.py example" to "[DOC] Update mnist.py example" on Jun 20, 2024
@doshi-kevin

Hey, can I work on this issue?

@doshi-kevin

Hey @orion160, I have made the changes and created a pull request at #1282. Could you please take a look?

@doshi-kevin

@orion160, any updates on this?

@orion160 (Author)

@doshi-kevin I am not a PyTorch maintainer, but it is great that you are interested in contributing.

That said, I see two problems. First, you appear to be applying the commit directly, which means there is neither authorship nor co-authorship recorded for you in the git history.

Second, your PR is about 600+ commits long, so it is possible that you did a merge and are now trying to push an altered git history.

The usual approach is to rebase your changes onto the current head of the upstream main branch (for example, fetch the upstream remote and run git rebase upstream/main on your branch) and then open the PR.

@doshi-kevin

Sure, I will make the changes; otherwise I will create a new fork and a new pull request. @orion160

@doshi-kevin

Hey @orion160, please check #1283.

@orion160 (Author)

It seems good

@doshi-kevin

Hope this gets merged, @orion160 @msaroufim.
