Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #122

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
"""
This module is used to serve a PyTorch model as a FastAPI service. It includes the necessary steps to load the model,
preprocess the input image, and make a prediction.
"""

from fastapi import FastAPI, UploadFile, File
from PIL import Image
import torch
from torchvision import transforms
from main import Net # Importing Net class from main.py

# Restore the trained classifier
# A fresh network is put into evaluation mode and then populated with the
# parameters stored in the saved state dictionary.
model = Net().eval()
model.load_state_dict(torch.load("mnist_model.pth"))

# Transform used for preprocessing the image
# The transform is used to convert the input image to a tensor and normalize it.
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
Expand All @@ -19,10 +26,29 @@

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """
    Classify an uploaded image with the loaded model.

    The upload is opened with PIL, converted to single-channel grayscale,
    resized to 28x28 if necessary, preprocessed with the module-level
    ``transform`` and passed through the module-level ``model``.

    Parameters:
        file: The uploaded file containing the image.

    Returns:
        A dictionary ``{"prediction": <int>}`` holding the index of the
        highest-scoring model output.
    """
    # Open the image file and convert it to grayscale ("L" = one 8-bit channel)
    image = Image.open(file.file).convert("L")

    # The network reshapes its input to rows of 28 * 28 values. Any other
    # resolution would either raise or be split into several bogus samples,
    # so bring the upload to the expected size first.
    if image.size != (28, 28):
        image = image.resize((28, 28))

    # Preprocess the image and add the leading batch dimension
    batch = transform(image).unsqueeze(0)

    # Inference only: no gradients are needed
    with torch.no_grad():
        output = model(batch)
        predicted = output.argmax(dim=1)

    # Return the prediction as a plain int so it is JSON-serializable
    return {"prediction": int(predicted[0])}
61 changes: 50 additions & 11 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,87 @@
from PIL import Image
"""
This module is used to train a PyTorch model on the MNIST dataset. It includes the necessary steps to preprocess the data,
define the model, and train the model.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms

# Step 1: Load MNIST Data and Preprocess
# The transform variable is used to preprocess the MNIST data by converting
# the images to tensors and normalizing them (mean 0.5, std 0.5).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Training split, stored in (and downloaded to, if missing) the current
# directory, served in shuffled batches of 64. Built exactly once.
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)


# Step 2: Define the PyTorch Model
class Net(nn.Module):
    """
    Small fully connected classifier for 28x28 inputs with 10 outputs.

    Attributes:
        fc1: Linear layer mapping 28 * 28 inputs to 128 features.
        fc2: Linear layer mapping 128 features to 64 features.
        fc3: Linear layer mapping 64 features to the 10 outputs.
    """

    def __init__(self):
        """
        Create the three linear layers.
        """
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """
        Run the forward pass.

        Parameters:
            x: The input tensor; it is reshaped to rows of 28 * 28 values.

        Returns:
            Log-probabilities over the 10 outputs, one row per input row.
        """
        activations = x.view(-1, 28 * 28)
        # Both hidden layers share the same pattern: linear map, then ReLU.
        for hidden_layer in (self.fc1, self.fc2):
            activations = nn.functional.relu(hidden_layer(activations))
        logits = self.fc3(activations)
        return nn.functional.log_softmax(logits, dim=1)


# Step 3: Train the Model
# NOTE(review): these statements run at module level, so importing Net from
# this module also runs the training below — consider a __main__ guard.
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# NLLLoss expects log-probabilities, which is what Net.forward returns.
criterion = nn.NLLLoss()

# Define the number of training epochs
epochs = 3

# Start the training loop: one full pass over trainloader per epoch
for _ in range(epochs):
    # For each batch of images and labels in the trainloader
    for images, labels in trainloader:
        # Zero the gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass: compute the output of the model on the images
        output = model(images)

        # Compute the loss between the output and the labels
        loss = criterion(output, labels)

        # Backward pass: compute the gradients of the loss with respect to
        # the model parameters
        loss.backward()

        # Update the model parameters
        optimizer.step()

# Persist the trained parameters once, after all epochs have finished
torch.save(model.state_dict(), "mnist_model.pth")