Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #64

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
"""
This module is used to create a FastAPI application that serves a model trained on the MNIST dataset.
The application provides an endpoint that accepts an image file, preprocesses the image, and returns a prediction of the digit in the image.
"""

# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints.
from fastapi import FastAPI, UploadFile, File
# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# PyTorch is an open source machine learning library
import torch
# torchvision is a library for image and video processing
from torchvision import transforms
from main import Net # Importing Net class from main.py
# Importing Net class from main.py
from main import Net

# Load the model
# Load the model trained on the MNIST dataset
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
# Transform used for preprocessing the image. It converts the image to a tensor and normalizes it.
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
Expand All @@ -19,10 +29,25 @@

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
"""
This function accepts an image file, preprocesses the image, and returns a prediction of the digit in the image.

Parameters:
file (UploadFile): The image file to be processed.

Returns:
dict: A dictionary with the key 'prediction' and the predicted digit as the value.
"""
# Open the image file and convert it to grayscale
image = Image.open(file.file).convert("L")
# Preprocess the image
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
# Add batch dimension
image = image.unsqueeze(0)
with torch.no_grad():
# Compute the output of the model on the input image
output = model(image)
# Get the predicted class
_, predicted = torch.max(output.data, 1)
# Return the prediction
return {"prediction": int(predicted[0])}
30 changes: 30 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,57 @@
"""
This module is used to train a simple neural network on the MNIST dataset using PyTorch.
"""

# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# PyTorch is an open source machine learning library
import torch
# nn is a module of PyTorch that provides classes for building neural networks
import torch.nn as nn
# optim is a module that implements various optimization algorithms
import torch.optim as optim
# torchvision is a library for image and video processing
from torchvision import datasets, transforms
# DataLoader is a utility class that provides the ability to batch, shuffle and load data in parallel
from torch.utils.data import DataLoader
# numpy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# The transforms are used to preprocess the MNIST dataset
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# trainset represents the MNIST dataset
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
# trainloader is a DataLoader instance that provides the ability to batch, shuffle and load the trainset in parallel
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
"""
This class represents a simple neural network with one hidden layer.

The forward method implements the forward pass of the neural network.
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
# Flatten the input tensor
x = x.view(-1, 28 * 28)
# Apply the first fully connected layer and the ReLU activation function
x = nn.functional.relu(self.fc1(x))
# Apply the second fully connected layer and the ReLU activation function
x = nn.functional.relu(self.fc2(x))
# Apply the third fully connected layer
x = self.fc3(x)
# Apply the log softmax activation function
return nn.functional.log_softmax(x, dim=1)

# Step 3: Train the Model
Expand All @@ -38,11 +62,17 @@ def forward(self, x):
# Training loop
epochs = 3
for epoch in range(epochs):
# Iterate over the batches of images and labels
for images, labels in trainloader:
# Zero the gradients
optimizer.zero_grad()
# Forward pass: compute the output of the model on the input images
output = model(images)
# Compute the loss of the output
loss = criterion(output, labels)
# Backward pass: compute the gradients of the loss with respect to the model's parameters
loss.backward()
# Update the model's parameters
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")