From 90bbe88f7d4709422a3e1e48d9b987498777632b Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 23:09:52 +0000 Subject: [PATCH] feat: Updated src/main.py --- src/main.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..f361f2e 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np +from cnn import CNN # Import the CNN class # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -16,19 +17,14 @@ trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model -class Net(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) +# Create an instance of the CNN class +cnn = CNN() + +# Train the CNN +cnn.train(trainloader, lr=0.001, epochs=10) + +# Save the trained model +cnn.save_model("mnist_model.pth") # Step 3: Train the Model model = Net()