Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #135

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,47 @@
from fastapi import FastAPI, UploadFile, File
from PIL import Image
"""
This script defines a FastAPI application that uses the PyTorch model defined in main.py to make predictions on uploaded images.
"""

import torch
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from torchvision import transforms

from main import Net # Importing Net class from main.py

# Load the model
# The model is loaded from the saved state dictionary and set to evaluation mode with model.eval()
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# The transformation pipeline consists of two steps:
# 1. transforms.ToTensor() - Converts the input image to PyTorch tensor.
# 2. transforms.Normalize((0.5,), (0.5,)) - Normalizes the tensor with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

app = FastAPI()


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
"""
This function takes an uploaded file as input, preprocesses the image, and makes a prediction using the PyTorch model.
The prediction is returned as a dictionary with the key 'prediction'.
"""
# Open the image file and convert it to grayscale
image = Image.open(file.file).convert("L")
# Apply the transformation pipeline to the image
image = transform(image)
# Add a batch dimension to the image tensor
image = image.unsqueeze(0) # Add batch dimension
# Make a prediction with the model
with torch.no_grad():
output = model(image)
# Get the class with the highest probability
_, predicted = torch.max(output.data, 1)
# Return the prediction as a dictionary
return {"prediction": int(predicted[0])}
45 changes: 35 additions & 10 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,60 @@
from PIL import Image
"""
This script defines the data loading and preprocessing steps, as well as the PyTorch model for MNIST digit classification.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# The transformation pipeline consists of two steps:
# 1. transforms.ToTensor() - Converts the input image to PyTorch tensor.
# 2. transforms.Normalize((0.5,), (0.5,)) - Normalizes the tensor with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
"""
This class defines the PyTorch model for MNIST digit classification.
The model consists of three fully connected layers.
"""


class Net(nn.Module):
def __init__(self):
super().__init__()
# First fully connected layer, takes input of size 28*28 and outputs size 128.
self.fc1 = nn.Linear(28 * 28, 128)
# Second fully connected layer, takes input of size 128 and outputs size 64.
self.fc2 = nn.Linear(128, 64)
# Third fully connected layer, takes input of size 64 and outputs size 10 (for 10 digit classes).
self.fc3 = nn.Linear(64, 10)


"""
This method defines the forward pass of the model.
It applies the following transformations to the input:
1. Reshapes the input to a 1D tensor.
2. Applies the first fully connected layer followed by a ReLU activation function.
3. Applies the second fully connected layer followed by a ReLU activation function.
4. Applies the third fully connected layer.
5. Applies a log softmax function to the output of the third layer.
"""

def forward(self, x):
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return nn.functional.log_softmax(x, dim=1)


# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
Expand All @@ -45,4 +70,4 @@ def forward(self, x):
loss.backward()
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")
torch.save(model.state_dict(), "mnist_model.pth")