Add module, class, and function docstrings plus explanatory comments to
src/api.py and src/main.py.

In src/main.py this patch also drops the `from PIL import Image` and
`import numpy as np` lines, moves the torchvision import after the
torch.utils.data import, reformats the transforms.Compose call, switches the
MNIST root literal to double quotes, renames the epoch loop variable to `_`,
and ends the file with a newline.

NOTE(review): line breaks, blank context lines, and indentation were restored
to match the hunk line counts; confirm whitespace against the target tree
before applying.

diff --git a/src/api.py b/src/api.py
index 36c257a..ea5eb1a 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,3 +1,8 @@
+"""
+This module is used to serve a PyTorch model as a FastAPI service. It includes the necessary steps to load the model,
+preprocess the input image, and make a prediction.
+"""
+
 from fastapi import FastAPI, UploadFile, File
 from PIL import Image
 import torch
@@ -5,11 +10,13 @@
 from main import Net # Importing Net class from main.py
 
 # Load the model
+# The model is loaded from the saved state dictionary and set to evaluation mode.
 model = Net()
 model.load_state_dict(torch.load("mnist_model.pth"))
 model.eval()
 
 # Transform used for preprocessing the image
+# The transform is used to convert the input image to a tensor and normalize it.
 transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
@@ -19,10 +26,29 @@
 
 @app.post("/predict/")
 async def predict(file: UploadFile = File(...)):
+    """
+    This function takes an uploaded file as input, preprocesses the image, makes a prediction using the model,
+    and returns the prediction.
+
+    Parameters:
+        file: The uploaded file containing the image.
+
+    Returns:
+        A dictionary containing the prediction made by the model.
+    """
+    # Open the image file and convert it to grayscale
     image = Image.open(file.file).convert("L")
+
+    # Preprocess the image using the transform
     image = transform(image)
+
+    # Add a batch dimension to the image
     image = image.unsqueeze(0) # Add batch dimension
+
+    # Make a prediction using the model
     with torch.no_grad():
         output = model(image)
         _, predicted = torch.max(output.data, 1)
+
+    # Return the prediction
     return {"prediction": int(predicted[0])}
diff --git a/src/main.py b/src/main.py
index 243a31e..75f1759 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,48 +1,87 @@
-from PIL import Image
+"""
+This module is used to train a PyTorch model on the MNIST dataset. It includes the necessary steps to preprocess the data,
+define the model, and train the model.
+"""
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
-import numpy as np
+from torchvision import datasets, transforms
 
 # Step 1: Load MNIST Data and Preprocess
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# The transform variable is used to preprocess the MNIST data by converting the images to tensors and normalizing them.
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
-trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
+trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
 trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
 
+
 # Step 2: Define the PyTorch Model
 class Net(nn.Module):
+    """
+    This class defines a simple feed-forward neural network for the MNIST dataset. It includes three fully connected layers.
+
+    Attributes:
+        fc1: The first fully connected layer.
+        fc2: The second fully connected layer.
+        fc3: The third fully connected layer.
+    """
+
     def __init__(self):
+        """
+        Initializes the Net class by defining the three fully connected layers.
+        """
         super().__init__()
         self.fc1 = nn.Linear(28 * 28, 128)
         self.fc2 = nn.Linear(128, 64)
         self.fc3 = nn.Linear(64, 10)
-    
+
     def forward(self, x):
+        """
+        Defines the forward pass of the network.
+
+        Parameters:
+            x: The input tensor.
+
+        Returns:
+            The output tensor after passing through the network.
+        """
         x = x.view(-1, 28 * 28)
         x = nn.functional.relu(self.fc1(x))
         x = nn.functional.relu(self.fc2(x))
         x = self.fc3(x)
         return nn.functional.log_softmax(x, dim=1)
 
+
 # Step 3: Train the Model
 model = Net()
 optimizer = optim.SGD(model.parameters(), lr=0.01)
 criterion = nn.NLLLoss()
 
 # Training loop
+# Define the number of training epochs
 epochs = 3
-for epoch in range(epochs):
+
+# Start the training loop
+for _ in range(epochs):
+    # For each batch of images and labels in the trainloader
     for images, labels in trainloader:
+        # Zero the gradients
         optimizer.zero_grad()
+
+        # Forward pass: compute the output of the model on the images
         output = model(images)
+
+        # Compute the loss between the output and the labels
         loss = criterion(output, labels)
+
+        # Backward pass: compute the gradients of the loss with respect to the model parameters
         loss.backward()
+
+        # Update the model parameters
         optimizer.step()
 
-torch.save(model.state_dict(), "mnist_model.pth")
\ No newline at end of file
+torch.save(model.state_dict(), "mnist_model.pth")