Skip to content

Commit 375f6d6

Browse files
feat: Updated src/main.py
1 parent 6f6638f commit 375f6d6

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed

src/main.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,44 @@
1-
from PIL import Image
1+
import numpy as np
22
import torch
33
import torch.nn as nn
44
import torch.optim as optim
5-
from torchvision import datasets, transforms
5+
from cnn import CNN, train
6+
from PIL import Image
67
from torch.utils.data import DataLoader
7-
import numpy as np
8+
from torchvision import datasets, transforms
89

910
# Step 1: Load MNIST Data and Preprocess
1011
transform = transforms.Compose([
11-
transforms.ToTensor(),
12+
transforms.ToTensor(),
1213
transforms.Normalize((0.5,), (0.5,))
1314
])
1415

1516
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
1617
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
1718

18-
# Step 2: Define the PyTorch Model
19-
class Net(nn.Module):
20-
def __init__(self):
21-
super().__init__()
22-
self.fc1 = nn.Linear(28 * 28, 128)
23-
self.fc2 = nn.Linear(128, 64)
24-
self.fc3 = nn.Linear(64, 10)
19+
20+
21+
22+
23+
24+
25+
2526

26-
def forward(self, x):
27-
x = x.view(-1, 28 * 28)
28-
x = nn.functional.relu(self.fc1(x))
29-
x = nn.functional.relu(self.fc2(x))
30-
x = self.fc3(x)
31-
return nn.functional.log_softmax(x, dim=1)
32-
33-
# Step 3: Train the Model
27+
28+
29+
30+
31+
32+
33+
34+
3435
model = Net()
36+
model = CNN()
3537
optimizer = optim.SGD(model.parameters(), lr=0.01)
3638
criterion = nn.NLLLoss()
3739

3840
# Training loop
39-
epochs = 3
40-
for epoch in range(epochs):
41-
for images, labels in trainloader:
42-
optimizer.zero_grad()
43-
output = model(images)
44-
loss = criterion(output, labels)
45-
loss.backward()
46-
optimizer.step()
41+
train(model, trainloader, optimizer)
42+
4743

4844
torch.save(model.state_dict(), "mnist_model.pth")

0 commit comments

Comments
 (0)