Skip to content

Commit 90bbe88

Browse files
feat: Updated src/main.py
1 parent 3a75d53 commit 90bbe88

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

src/main.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torchvision import datasets, transforms
66
from torch.utils.data import DataLoader
77
import numpy as np
8+
from cnn import CNN # Import the CNN class
89

910
# Step 1: Load MNIST Data and Preprocess
1011
transform = transforms.Compose([
@@ -16,19 +17,14 @@
1617
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
1718

1819
# Step 2: Define the PyTorch Model
19-
class Net(nn.Module):
20-
def __init__(self):
21-
super().__init__()
22-
self.fc1 = nn.Linear(28 * 28, 128)
23-
self.fc2 = nn.Linear(128, 64)
24-
self.fc3 = nn.Linear(64, 10)
25-
26-
def forward(self, x):
27-
x = x.view(-1, 28 * 28)
28-
x = nn.functional.relu(self.fc1(x))
29-
x = nn.functional.relu(self.fc2(x))
30-
x = self.fc3(x)
31-
return nn.functional.log_softmax(x, dim=1)
20+
# Create an instance of the CNN class
21+
cnn = CNN()
22+
23+
# Train the CNN
24+
cnn.train(trainloader, lr=0.001, epochs=10)
25+
26+
# Save the trained model
27+
cnn.save_model("mnist_model.pth")
3228

3329
# Step 3: Train the Model
3430
model = Net()

0 commit comments

Comments
 (0)