Skip to content

Commit efdeda6

Browse files
Sandbox run src/main.py
1 parent 375f6d6 commit efdeda6

File tree

1 file changed

+7
-22
lines changed

1 file changed

+7
-22
lines changed

src/main.py

+7-22
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,21 @@
22
import torch
33
import torch.nn as nn
44
import torch.optim as optim
5-
from cnn import CNN, train
65
from PIL import Image
76
from torch.utils.data import DataLoader
87
from torchvision import datasets, transforms
98

9+
from cnn import CNN, train
10+
1011
# Step 1: Load MNIST Data and Preprocess
11-
transform = transforms.Compose([
12-
transforms.ToTensor(),
13-
transforms.Normalize((0.5,), (0.5,))
14-
])
12+
transform = transforms.Compose(
13+
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
14+
)
1515

16-
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
16+
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
1717
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
1818

1919

20-
21-
22-
23-
24-
25-
26-
27-
28-
29-
30-
31-
32-
33-
34-
3520
model = Net()
3621
model = CNN()
3722
optimizer = optim.SGD(model.parameters(), lr=0.01)
@@ -41,4 +26,4 @@
4126
train(model, trainloader, optimizer)
4227

4328

44-
torch.save(model.state_dict(), "mnist_model.pth")
29+
torch.save(model.state_dict(), "mnist_model.pth")

0 commit comments

Comments
 (0)