diff --git a/src/cnn.py b/src/cnn.py index 4dd0d15..601e776 100644 --- a/src/cnn.py +++ b/src/cnn.py @@ -30,6 +30,7 @@ def forward(self, x): output = nn.functional.log_softmax(x, dim=1) return output + def train(network, train_loader, optimizer): network.train() for batch_idx, (data, target) in enumerate(train_loader): @@ -39,13 +40,13 @@ def train(network, train_loader, optimizer): loss.backward() optimizer.step() + if __name__ == "__main__": - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] + ) - trainset = datasets.MNIST('.', download=True, train=True, transform=transform) + trainset = datasets.MNIST(".", download=True, train=True, transform=transform) trainloader = DataLoader(trainset, batch_size=64, shuffle=True) network = CNN()