Add comments and docstrings to main.py and api.py #22

Closed
wants to merge 2 commits into from
27 changes: 24 additions & 3 deletions src/api.py
@@ -1,15 +1,25 @@
"""
This module contains a FastAPI application that serves a prediction endpoint for a simple feed-forward neural network trained on the MNIST dataset.
The prediction endpoint accepts an image file, preprocesses the image, and returns a prediction of the digit in the image.
"""

# FastAPI is a modern, high-performance web framework for building APIs with Python, based on standard Python type hints.
from fastapi import FastAPI, UploadFile, File
# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# PyTorch is an open source machine learning library based on the Torch library
import torch
# torchvision is a library for PyTorch that provides datasets and models for computer vision
from torchvision import transforms
from main import Net # Importing Net class from main.py
# Importing Net class from main.py which defines the neural network architecture
from main import Net

# Load the model
# Load the model. Net class defines the neural network architecture. load_state_dict loads a model’s parameter dictionary using a deserialized state_dict. eval sets the module in evaluation mode.
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
# Transform used for preprocessing the image. transforms.Compose composes several transforms together. ToTensor converts a PIL Image or numpy.ndarray to tensor, and Normalize normalizes a tensor image with mean and standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
@@ -19,6 +29,17 @@

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """
    Accept an uploaded image file, preprocess it, and return the predicted digit.

    The image is opened and converted to grayscale, then converted to a tensor and normalized.
    The resulting tensor is passed through the model to obtain a prediction.

    Parameters:
    file (UploadFile): The image file to classify.

    Returns:
    dict: A dictionary with the key 'prediction' and the predicted digit as the value.
    """
    image = Image.open(file.file).convert("L")
    image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension
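
The preprocessing described in the `predict` docstring can be illustrated numerically without PyTorch. The following is a minimal sketch, not the PR's code: it assumes 8-bit grayscale input, and `to_unit_range`, `normalize`, and `preprocess` are hypothetical helper names that mimic, per pixel, what `ToTensor` followed by `Normalize((0.5,), (0.5,))` computes.

```python
def to_unit_range(pixel: int) -> float:
    """Scale an 8-bit pixel (0..255) to [0.0, 1.0], as ToTensor does for uint8 images."""
    return pixel / 255.0


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Apply (value - mean) / std, the per-channel formula used by Normalize."""
    return (value - mean) / std


def preprocess(pixels):
    """Map a 2D list of 8-bit grayscale pixels to values in [-1.0, 1.0]."""
    return [[normalize(to_unit_range(p)) for p in row] for row in pixels]


# Hypothetical 2x2 "image"; a real MNIST-style input would be 28x28.
image = [[0, 128], [255, 64]]
processed = preprocess(image)

# Wrapping twice mirrors the channel dimension plus the batch dimension
# added by unsqueeze(0): the nesting is [batch][channel][row][column].
batch = [[processed]]
print(processed)
```

With a mean and standard deviation of 0.5, black pixels map to -1.0 and white pixels map to 1.0, so the model sees inputs centered around zero.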
27 changes: 27 additions & 0 deletions src/main.py
@@ -1,29 +1,56 @@
"""
This module contains code for training a simple feed-forward neural network on the MNIST dataset using PyTorch.
The MNIST dataset consists of 28x28 grayscale images of handwritten digits.
The PyTorch model defined in this module is a simple feed-forward neural network with two hidden layers.
The data is preprocessed by converting the images to tensors and normalizing them.
"""

# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# PyTorch is an open source machine learning library based on the Torch library
import torch
# torch.nn is a PyTorch subpackage that provides classes for building neural networks
import torch.nn as nn
# torch.optim is a package implementing various optimization algorithms
import torch.optim as optim
# torchvision is a library for PyTorch that provides datasets and models for computer vision
from torchvision import datasets, transforms
# DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset
from torch.utils.data import DataLoader
# numpy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# transforms.Compose composes several transforms together, ToTensor converts a PIL Image or numpy.ndarray to tensor,
# and Normalize normalizes a tensor image with mean and standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# datasets.MNIST loads the MNIST dataset from the given root directory, selects the training or test split via the train argument,
# applies the transforms, and downloads the data if it is not already present
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
# DataLoader wraps an iterable around the Dataset to enable easy access to the samples
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
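
To make the effect of `batch_size=64` and `shuffle=True` concrete, here is a minimal pure-Python sketch of shuffled mini-batching. It is an illustration under stated assumptions, not how `DataLoader` is implemented: `iter_batches` is a hypothetical helper, the `seed` parameter is added only for reproducibility, and the sample count of 60,000 is the size of the MNIST training split.

```python
import random


def iter_batches(num_samples, batch_size, shuffle=True, seed=None):
    """Yield lists of sample indices, one list per mini-batch."""
    indices = list(range(num_samples))
    if shuffle:
        # Reorder the indices once per pass over the data.
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        # The final slice may be shorter than batch_size; it is kept,
        # matching DataLoader's default of drop_last=False.
        yield indices[start:start + batch_size]


batches = list(iter_batches(60_000, 64, seed=0))
print(len(batches), len(batches[-1]))  # prints "938 32"
```

Because 60,000 is not a multiple of 64, one pass over the data produces 937 full batches and a final batch of 32 samples.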

# Step 2: Define the PyTorch Model
class Net(nn.Module):
    """
    This class defines a simple feed-forward neural network with two hidden layers.
    The network takes as input a 28x28 image, flattens it into a 784-dimensional vector,
    and then passes it through two hidden layers with ReLU activation,
    followed by an output layer with log softmax activation.
    """
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """
        Defines the forward pass of the network.
        """
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
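
The rest of `forward` is collapsed in this diff view, so the following is a sketch based only on the class docstring, which says the output layer uses log softmax. It shows, in plain Python, what log softmax computes over ten class scores and how a digit could be read off the result. The names `log_softmax` and `predict_digit` and the example logits are hypothetical; how the PR's code actually extracts the prediction is not shown here.

```python
import math


def log_softmax(logits):
    """Return log(softmax(logits)), using the max-shift trick for numerical stability."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def predict_digit(log_probs):
    """Return the index of the largest log-probability, i.e. the most likely class."""
    return max(range(len(log_probs)), key=log_probs.__getitem__)


# Hypothetical raw scores for the ten digit classes 0..9.
logits = [0.1, 2.0, -1.0, 0.5, 0.0, 0.3, -0.2, 1.0, 0.7, -0.5]
log_probs = log_softmax(logits)
print(predict_digit(log_probs))
```

Exponentiating the outputs gives probabilities that sum to 1, and because the logarithm is monotonic, the argmax of the log-probabilities equals the argmax of the raw scores.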