Skip to content

Commit

Permalink
feat: Updated src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 13, 2023
1 parent 3a75d53 commit 90bbe88
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from cnn import CNN # Import the CNN class

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
Expand All @@ -16,19 +17,14 @@
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return nn.functional.log_softmax(x, dim=1)
# Create an instance of the CNN class
cnn = CNN()

# Train the CNN
cnn.train(trainloader, lr=0.001, epochs=10)

# Save the trained model
cnn.save_model("mnist_model.pth")

# Step 3: Train the Model
model = Net()
Expand Down

0 comments on commit 90bbe88

Please sign in to comment.