Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #111

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,47 @@
from fastapi import FastAPI, UploadFile, File
from PIL import Image
"""
This module defines a FastAPI application that serves a PyTorch model trained on the MNIST dataset.
It includes the necessary imports, model loading, image preprocessing, and prediction endpoint.
"""
import torch
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from torchvision import transforms

from main import Net # Importing Net class from main.py

# Load the trained PyTorch model from the saved state dictionary.
# map_location="cpu" lets a checkpoint that was saved on a GPU machine load on
# a CPU-only server; for a CPU-saved checkpoint it changes nothing.
model = Net()
model.load_state_dict(torch.load("mnist_model.pth", map_location="cpu"))
# Switch the module to evaluation mode, since it is only used for inference here
model.eval()

# Define the transformations to be applied to the images for preprocessing:
# convert to a tensor, then normalize with mean 0.5 and std 0.5
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

app = FastAPI()


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """
    Preprocess an uploaded image and classify it with the loaded PyTorch model.

    Parameters:
        file: The uploaded image file.

    Returns:
        A dictionary with the key 'prediction' and the predicted digit as the value.
    """
    # Open the image file and convert it to grayscale
    image = Image.open(file.file).convert("L")
    # The model flattens its input to rows of 28 * 28 values, so an image of any
    # other size would be split into several bogus rows (or fail); resize first.
    image = image.resize((28, 28))
    # Apply the defined transformations, then add exactly one batch dimension
    batch = transform(image).unsqueeze(0)
    # Make a prediction with the model without computing gradients
    with torch.no_grad():
        output = model(batch)
    # Get the digit with the highest prediction score
    predicted = output.argmax(dim=1)
    return {"prediction": int(predicted[0])}
58 changes: 48 additions & 10 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,86 @@
from PIL import Image
"""
This module is used to train a PyTorch model on the MNIST dataset.
It includes the necessary imports, data loading and preprocessing, model definition, and training loop.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms

# Step 1: Load MNIST Data and Preprocess
# Define the transformations to be applied to the images.
# The images are converted to tensors and normalized (mean 0.5, std 0.5).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Build the MNIST training split once, rooted at the current directory
# (download=True), and serve it in shuffled batches of 64.
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)


# Step 2: Define the PyTorch Model
class Net(nn.Module):
    """
    A simple feed-forward neural network for classifying MNIST images.
    It inherits from the nn.Module class of PyTorch.

    Attributes:
        fc1: First fully connected layer, from 28 * 28 inputs to 128 nodes.
        fc2: Second fully connected layer, from 128 to 64 nodes.
        fc3: Third fully connected layer, from 64 to 10 nodes (output layer).
    """

    def __init__(self):
        """
        Initializes the network by defining its layers.
        """
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """
        Defines the forward pass of the network.

        Parameters:
            x: The input tensor; it is flattened to rows of 28 * 28 values.

        Returns:
            The log-softmax scores over the 10 output nodes, one row per
            flattened input row.
        """
        # reshape gives the same result as view on contiguous inputs, but also
        # accepts non-contiguous ones (e.g. transposed or sliced tensors),
        # where view would raise.
        x = x.reshape(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return nn.functional.log_softmax(x, dim=1)


# Step 3: Train the Model
# NOTE(review): these statements run at module import time, so
# `from main import Net` also triggers a full training run; consider moving
# them under an `if __name__ == "__main__":` guard once callers are confirmed.
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# Define the number of training epochs
epochs = 3

# Start the training loop
for epoch in range(epochs):
    # For each batch of images and labels in the trainloader
    for images, labels in trainloader:
        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass: compute the output of the model on the images
        output = model(images)
        # Compute the loss between the output and the true labels
        loss = criterion(output, labels)
        # Backward pass: compute the gradients of the loss with respect to the model parameters
        loss.backward()
        # Update the model parameters
        optimizer.step()

# Save the trained weights once, after all epochs have finished
torch.save(model.state_dict(), "mnist_model.pth")