diff --git a/src/api.py b/src/api.py index 36c257a..4a1dc79 100644 --- a/src/api.py +++ b/src/api.py @@ -1,28 +1,55 @@ -from fastapi import FastAPI, UploadFile, File -from PIL import Image +""" +This file loads the trained MNIST model and exposes a FastAPI endpoint that preprocesses an uploaded image and returns the model's prediction. +""" + +# PyTorch is an open source machine learning library based on the Torch library import torch + +# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints. +from fastapi import FastAPI, File, UploadFile + +# PIL is used for opening, manipulating, and saving many different image file formats +from PIL import Image + +# torchvision is a library for PyTorch that provides datasets, models, and image transforms for computer vision from torchvision import transforms -from main import Net # Importing Net class from main.py -# Load the model +# Importing Net class from main.py +from main import Net + +# Load the model and set it to evaluation mode model = Net() model.load_state_dict(torch.load("mnist_model.pth")) model.eval() -# Transform used for preprocessing the image -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +# Transform used for preprocessing the image. It converts the image to a tensor and normalizes it. +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) +# FastAPI application app = FastAPI() + @app.post("/predict/") async def predict(file: UploadFile = File(...)): + """ + This function takes an image file as input, preprocesses the image, and makes a prediction using the loaded model. + Parameters: + file (UploadFile): The image file to be processed and predicted. + Returns: + dict: A dictionary with the prediction result. 
+ """ + # Open the image file and convert it to grayscale image = Image.open(file.file).convert("L") + # Apply the transform to the image image = transform(image) - image = image.unsqueeze(0) # Add batch dimension + # Add a batch dimension to the image + image = image.unsqueeze(0) + # Make a prediction without calculating gradients with torch.no_grad(): output = model(image) + # Get the index of the max log-probability _, predicted = torch.max(output.data, 1) + # Return the prediction result return {"prediction": int(predicted[0])} diff --git a/src/main.py b/src/main.py index 243a31e..8285f4b 100644 --- a/src/main.py +++ b/src/main.py @@ -1,28 +1,53 @@ -from PIL import Image +""" +This file loads and preprocesses the MNIST dataset, defines a PyTorch model, trains it, and saves the trained weights. +""" + +# NumPy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices +import numpy as np + +# PyTorch is an open source machine learning library based on the Torch library import torch + +# torch.nn is a PyTorch subpackage that provides classes to build neural networks import torch.nn as nn + +# torch.optim is a package implementing various optimization algorithms import torch.optim as optim -from torchvision import datasets, transforms + +# DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset from torch.utils.data import DataLoader -import numpy as np + +# torchvision is a library for PyTorch that provides datasets and models for computer vision +from torchvision import datasets, transforms + +# NOTE: PIL is no longer imported in this module; image loading is handled in api.py + # Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +# transforms.Compose composes several transforms together, in this case, ToTensor and Normalize +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 
+) -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) +# datasets.MNIST loads the MNIST training split with download enabled, using the current directory as its root, and applies the specified transform +trainset = datasets.MNIST(".", download=True, train=True, transform=transform) +# DataLoader provides an iterable over the training set in shuffled batches of 64 trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + # Step 2: Define the PyTorch Model class Net(nn.Module): + """ + This class defines a simple feed-forward neural network for the MNIST dataset. + It has three fully connected layers and uses ReLU activation function for the first two layers and log softmax for the output layer. + """ + def __init__(self): super().__init__() self.fc1 = nn.Linear(28 * 28, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) - + def forward(self, x): x = x.view(-1, 28 * 28) x = nn.functional.relu(self.fc1(x)) @@ -30,6 +55,7 @@ def forward(self, x): x = self.fc3(x) return nn.functional.log_softmax(x, dim=1) + # Step 3: Train the Model model = Net() optimizer = optim.SGD(model.parameters(), lr=0.01) @@ -45,4 +71,4 @@ def forward(self, x): loss.backward() optimizer.step() -torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file +torch.save(model.state_dict(), "mnist_model.pth")