Skip to content

Commit 53a51f8

Browse files
Sandbox run src/cnn.py
1 parent 7cfe902 commit 53a51f8

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/cnn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def train_cnn(self, trainloader, epochs=3):
3636

3737
torch.save(self.state_dict(), "mnist_cnn_model.pth")
3838

39-
transform = transforms.Compose([
40-
transforms.ToTensor(),
41-
transforms.Normalize((0.5,), (0.5,))
42-
])
4339

44-
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
40+
transform = transforms.Compose(
41+
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
42+
)
43+
44+
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
4545
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
4646

4747
cnn = CNN()

0 commit comments

Comments
 (0)