Skip to content

Add Comments and Docstrings to main.py and api.py #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from fastapi import FastAPI, UploadFile, File
from PIL import Image
"""
This module provides an API endpoint for predicting the digit in an uploaded image using a pre-trained PyTorch model.

The API endpoint '/predict/' accepts POST requests with an image file, preprocesses the image, and returns the predicted digit.
"""
import torch
from torchvision import transforms
from fastapi import FastAPI, File, UploadFile
from main import Net # Importing Net class from main.py
from PIL import Image
from torchvision import transforms

# Load the model
model = Net()
Expand All @@ -26,3 +31,20 @@ async def predict(file: UploadFile = File(...)):
output = model(image)
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}
output = model(image)
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}
return {"prediction": int(predicted[0])}
Parameters:
- file (UploadFile): The image file to predict.

Returns:
- dict: A dictionary with the key 'prediction' and the predicted digit as the value.
"""
image = Image.open(file.file).convert("L")
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
with torch.no_grad():
output = model(image)
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}
19 changes: 16 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from PIL import Image
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
Expand All @@ -17,6 +17,19 @@

# Step 2: Define the PyTorch Model
class Net(nn.Module):
"""
This is a simple feed-forward neural network model.

Methods:
- __init__: Initializes the model layers.
- forward: Defines the forward pass of the model.

Parameters for forward:
- x (torch.Tensor): The input tensor.

Returns from forward:
- torch.Tensor: The output tensor after passing through the model.
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
Expand Down