Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #146

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
"""
This module defines a FastAPI application that serves a model prediction endpoint.
The endpoint accepts an image file, preprocesses the image, and returns a prediction from a pre-trained model.
"""

from fastapi import FastAPI, UploadFile, File
from PIL import Image
import torch
from torchvision import transforms
from main import Net # Importing Net class from main.py
from main import Net # Import the Net class from main.py

# Load the model
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()
# Load the pre-trained model
model = Net() # Initialize the model
model.load_state_dict(torch.load("mnist_model.pth")) # Load the pre-trained weights
model.eval() # Set the model to evaluation mode

# Transform used for preprocessing the image
# Define the transform used for preprocessing the image
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
transforms.ToTensor(), # Convert the image to a PyTorch tensor
transforms.Normalize((0.5,), (0.5,)) # Normalize the tensor
])

app = FastAPI()

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
image = Image.open(file.file).convert("L")
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
with torch.no_grad():
output = model(image)
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}
"""
Predict the digit in an uploaded image.

Parameters
----------
file : UploadFile
The image file to predict.

Returns
-------
dict
A dictionary with a single key "prediction" and the predicted digit as the value.
"""
image = Image.open(file.file).convert("L") # Open the image file and convert it to grayscale
image = transform(image) # Apply the preprocessing transform
image = image.unsqueeze(0) # Add a batch dimension to the tensor
with torch.no_grad(): # Disable gradient computation
output = model(image) # Forward pass through the model
_, predicted = torch.max(output.data, 1) # Get the index of the max log-probability
return {"prediction": int(predicted[0])} # Return the prediction as a JSON response
50 changes: 40 additions & 10 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
This module is used to train a simple neural network model on the MNIST dataset using PyTorch.
It includes steps for loading and preprocessing the dataset, defining the model, and training the model.
"""

from PIL import Image
import torch
import torch.nn as nn
Expand All @@ -7,28 +12,53 @@
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# Define the transformations to be applied on the images
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
transforms.ToTensor(), # Convert the images to tensors
transforms.Normalize((0.5,), (0.5,)) # Normalize the images
])

# Load the MNIST dataset and apply the transformations
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)

# Create a DataLoader for the dataset
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
"""
This class defines a simple feed-forward neural network model.

Methods
-------
forward(x)
Defines the forward pass of the model.
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
self.fc1 = nn.Linear(28 * 28, 128) # First fully connected layer
self.fc2 = nn.Linear(128, 64) # Second fully connected layer
self.fc3 = nn.Linear(64, 10) # Third fully connected layer

def forward(self, x):
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return nn.functional.log_softmax(x, dim=1)
"""
Defines the forward pass of the model.

Parameters
----------
x : torch.Tensor
The input tensor.

Returns
-------
torch.Tensor
The output tensor.
"""
x = x.view(-1, 28 * 28) # Flatten the input tensor
x = nn.functional.relu(self.fc1(x)) # Apply ReLU activation function after the first layer
x = nn.functional.relu(self.fc2(x)) # Apply ReLU activation function after the second layer
x = self.fc3(x) # Apply the third layer
return nn.functional.log_softmax(x, dim=1) # Apply log softmax to the output

# Step 3: Train the Model
model = Net()
Expand Down