Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #42

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
"""
This module is used to create a FastAPI application that serves a machine learning model trained on the MNIST dataset.
"""

# FastAPI is a modern, fast (high-performance) web framework for building APIs with Python 3.6+ based on standard Python type hints.
from fastapi import FastAPI, UploadFile, File
# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# PyTorch is an open source machine learning library
import torch
# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
from torchvision import transforms
from main import Net # Importing Net class from main.py
# Importing Net class from main.py
from main import Net

# Load the model
# Load the model. This is the trained neural network model that we will use for making predictions.
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
# Transform used for preprocessing the image: it converts the image to a PyTorch tensor and normalizes the pixel values with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# FastAPI application. This is the main entry point for our API.
app = FastAPI()

# This function is used to make predictions on uploaded images.
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
"""
Make a prediction on an uploaded image.

Parameters
----------
file : UploadFile
The image file to make a prediction on.

Returns
-------
dict
A dictionary with a single key 'prediction' and the predicted digit as the value.
"""
image = Image.open(file.file).convert("L")
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
Expand Down
41 changes: 41 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,70 @@
"""
This module is used to train a simple neural network model on the MNIST dataset using PyTorch.
"""

# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# PyTorch is an open source machine learning library
import torch
# torch.nn is a sublibrary of PyTorch, used for building and training neural networks
import torch.nn as nn
# torch.optim is a package implementing various optimization algorithms
import torch.optim as optim
# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
from torchvision import datasets, transforms
# DataLoader is a PyTorch class for batching and iterating over datasets
from torch.utils.data import DataLoader
# numpy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# The transforms convert the data to PyTorch tensors and normalize the pixel values (mean 0.5, standard deviation 0.5)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# The trainset is the MNIST dataset that we will train our model on
# The trainloader is a DataLoader, which handles batching and shuffling (here: shuffled batches of 64 samples); it can also load data in parallel, but no worker processes are configured here
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
# The Net class is a definition of a simple neural network
class Net(nn.Module):
"""
A simple fully connected neural network with two hidden layers.

...

Methods
-------
forward(x):
Defines the computation performed at every call.
"""
def __init__(self):
"""
Initializes the network's three fully connected (linear) layers.

The layers map 28 * 28 = 784 input features to 128 units, then to
64 units, then to 10 units.
"""
super().__init__()
# 784 -> 128: takes a flattened 28 * 28 input
self.fc1 = nn.Linear(28 * 28, 128)
# 128 -> 64
self.fc2 = nn.Linear(128, 64)
# 64 -> 10: presumably one output per MNIST digit class -- confirm against forward()
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
"""
Defines the computation performed at every call.

Parameters
----------
x : Tensor
The input data.

Returns
-------
Tensor
The output of the neural network.
"""
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
Expand Down