From 6f6638f69035db64be8e1f2cfb81e5288000138e Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 26 Nov 2023 01:21:06 +0000 Subject: [PATCH] Sandbox run src/cnn.py --- src/cnn.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/cnn.py b/src/cnn.py index 4dd0d15..601e776 100644 --- a/src/cnn.py +++ b/src/cnn.py @@ -30,6 +30,7 @@ def forward(self, x): output = nn.functional.log_softmax(x, dim=1) return output + def train(network, train_loader, optimizer): network.train() for batch_idx, (data, target) in enumerate(train_loader): @@ -39,13 +40,13 @@ def train(network, train_loader, optimizer): loss.backward() optimizer.step() + if __name__ == "__main__": - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] + ) - trainset = datasets.MNIST('.', download=True, train=True, transform=transform) + trainset = datasets.MNIST(".", download=True, train=True, transform=transform) trainloader = DataLoader(trainset, batch_size=64, shuffle=True) network = CNN()