diff --git a/src/api.py b/src/api.py index 36c257a..475b5c0 100644 --- a/src/api.py +++ b/src/api.py @@ -1,24 +1,48 @@ -from fastapi import FastAPI, UploadFile, File -from PIL import Image +""" +This script creates a FastAPI application for making predictions using the PyTorch model defined in main.py. +""" +# torch is the main PyTorch library import torch + +# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints. +from fastapi import FastAPI, File, UploadFile + +# PIL is used for opening, manipulating, and saving many different image file formats +from PIL import Image + +# torchvision.transforms provides classes for transforming images from torchvision import transforms -from main import Net # Importing Net class from main.py -# Load the model +# Importing Net class from main.py +from main import Net + +# Load the trained PyTorch model from a file model = Net() model.load_state_dict(torch.load("mnist_model.pth")) model.eval() -# Transform used for preprocessing the image -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +# Define a sequence of preprocessing steps to be applied to the input images +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) +# Create an instance of the FastAPI application app = FastAPI() + +# Define a route handler for making predictions using the model @app.post("/predict/") async def predict(file: UploadFile = File(...)): + """ + This function is a route handler for making predictions using the model. + It takes an image file as input and returns a prediction. + + Parameters: + file (UploadFile): The image file to predict. + + Returns: + dict: A dictionary with the prediction. + """ image = Image.open(file.file).convert("L") image = transform(image) image = image.unsqueeze(0) # Add batch dimension diff --git a/src/main.py b/src/main.py index 243a31e..7702210 100644 --- a/src/main.py +++ b/src/main.py @@ -1,29 +1,54 @@ -from PIL import Image +""" +This script loads and preprocesses the MNIST dataset, defines a PyTorch model, and trains the model. +""" +# PIL is used for opening, manipulating, and saving many different image file formats +import PIL.Image +# torch is the main PyTorch library import torch +# torch.nn provides classes for building neural networks import torch.nn as nn +# torch.optim provides classes for implementing various optimization algorithms import torch.optim as optim -from torchvision import datasets, transforms +# torchvision.datasets provides classes for loading and using various popular datasets +from torchvision import datasets +# torchvision.transforms provides classes for transforming images +from torchvision import transforms +# torch.utils.data provides classes for loading data in parallel from torch.utils.data import DataLoader +# numpy is used for numerical operations import numpy as np # Step 1: Load MNIST Data and Preprocess +# This is a sequence of preprocessing steps to be applied to the images in the MNIST dataset transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) +# This represents the MNIST training dataset trainset = datasets.MNIST('.', download=True, train=True, transform=transform) +# This is a data loader for batching and shuffling the training data trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model +# This class defines the architecture of the PyTorch model class Net(nn.Module): + """ + This class defines the architecture of the PyTorch model. + """ def __init__(self): + """ + This method initializes the model. + """ super().__init__() self.fc1 = nn.Linear(28 * 28, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) def forward(self, x): + """ + This method defines the forward pass of the model. + """ x = x.view(-1, 28 * 28) x = nn.functional.relu(self.fc1(x)) x = nn.functional.relu(self.fc2(x))