Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #60

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,44 @@
"""
This script is used to create a FastAPI application for making predictions using the PyTorch model defined in main.py.
"""

# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints.
from fastapi import FastAPI, UploadFile, File
# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# PyTorch is an open source machine learning library based on the Torch library
import torch
# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
from torchvision import transforms
from main import Net # Importing Net class from main.py
# Importing Net class from main.py
from main import Net

# Load the model
# Load the trained PyTorch model
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
# Define the transformations to be applied to the images that are uploaded to the FastAPI application
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# Create the FastAPI application
app = FastAPI()

# Route handler for the '/predict' endpoint of the FastAPI application
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
# Load the uploaded image
image = Image.open(file.file).convert("L")
# Apply the transformations
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
# Add a batch dimension
image = image.unsqueeze(0)
with torch.no_grad():
# Make the prediction
output = model(image)
_, predicted = torch.max(output.data, 1)
# Return the prediction
return {"prediction": int(predicted[0])}
14 changes: 14 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
"""
This script is used to load and preprocess the MNIST dataset, define a PyTorch model, and train the model.
"""

# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# PyTorch is an open source machine learning library based on the Torch library
import torch
# torch.nn is a sublibrary of PyTorch, used for building neural networks
import torch.nn as nn
# torch.optim is a package implementing various optimization algorithms
import torch.optim as optim
# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
from torchvision import datasets, transforms
# DataLoader is a PyTorch function for loading and representing data
from torch.utils.data import DataLoader
# numpy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# 'transform' is used to define the transformations to be applied to the images in the dataset
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# 'trainset' and 'trainloader' are used to load and preprocess the training data
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
# 'Net' class is used to define the structure of the neural network
class Net(nn.Module):
def __init__(self):
super().__init__()
Expand Down