diff --git a/src/main.py b/src/main.py index 243a31e..f361f2e 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np +from cnn import CNN # Import the CNN class # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -16,19 +17,14 @@ trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model -class Net(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) +# Create an instance of the CNN class +cnn = CNN() + +# Train the CNN +cnn.train(trainloader, lr=0.001, epochs=10) + +# Save the trained model +cnn.save_model("mnist_model.pth") # Step 3: Train the Model model = Net()