Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #102

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
"""
This module creates a FastAPI application for making predictions using the PyTorch model defined in main.py.
"""
# FastAPI is used for creating the API
from fastapi import FastAPI, UploadFile, File
# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# torch is the main PyTorch library
import torch
# torchvision.transforms provides classes for transforming images
from torchvision import transforms
from main import Net # Importing Net class from main.py
# Net is the PyTorch model defined in main.py
from main import Net

# Load the model
# 'model' represents the PyTorch model
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
# 'transform' is a sequence of transformations applied to the images
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# 'app' represents the FastAPI application
app = FastAPI()

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
"""
This function takes an image file as input, preprocesses it, passes it through the model, and returns the model's prediction.
The input is an image file and the return value is a dictionary with the key 'prediction' and the model's prediction as the value.
"""
image = Image.open(file.file).convert("L")
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
Expand Down
20 changes: 20 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,42 @@
"""
This module is used to load and preprocess the MNIST dataset and define a PyTorch model.
"""

# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image
# torch is the main PyTorch library
import torch
# torch.nn provides classes for building neural networks
import torch.nn as nn
# torch.optim provides classes for implementing various optimization algorithms
import torch.optim as optim
# torchvision.datasets provides classes for loading and manipulating datasets
from torchvision import datasets, transforms
# torchvision.transforms provides classes for transforming images
# torch.utils.data provides utilities for loading and manipulating data
from torch.utils.data import DataLoader
# numpy is used for numerical operations
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# 'transform' is a sequence of transformations applied to the images in the dataset
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# 'trainset' represents the MNIST dataset
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
# 'trainloader' is a data loader for the MNIST dataset
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
"""
This class defines a simple feed-forward neural network for classifying MNIST images.
The network has three fully connected layers and uses ReLU activation functions.
The forward method takes an input tensor, reshapes it, and passes it through the network.
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
Expand Down