diff --git a/src/api.py b/src/api.py
index 36c257a..0cc1a3d 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,28 +1,47 @@
-from fastapi import FastAPI, UploadFile, File
-from PIL import Image
+"""
+This script defines a FastAPI application that uses the PyTorch model defined in main.py to make predictions on uploaded images.
+"""
+
 import torch
+from fastapi import FastAPI, File, UploadFile
+from PIL import Image
 from torchvision import transforms
+
 from main import Net  # Importing Net class from main.py
 
 # Load the model
+# The model is loaded from the saved state dictionary and set to evaluation mode with model.eval()
 model = Net()
 model.load_state_dict(torch.load("mnist_model.pth"))
 model.eval()
 
 # Transform used for preprocessing the image
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# The transformation pipeline consists of two steps:
+# 1. transforms.ToTensor() - Converts the input image to PyTorch tensor.
+# 2. transforms.Normalize((0.5,), (0.5,)) - Normalizes the tensor with mean 0.5 and standard deviation 0.5.
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
 app = FastAPI()
 
+
 @app.post("/predict/")
 async def predict(file: UploadFile = File(...)):
+    """
+    This function takes an uploaded file as input, preprocesses the image, and makes a prediction using the PyTorch model.
+    The prediction is returned as a dictionary with the key 'prediction'.
+    """
+    # Open the image file and convert it to grayscale
     image = Image.open(file.file).convert("L")
+    # Apply the transformation pipeline to the image
     image = transform(image)
+    # Add a batch dimension to the image tensor
     image = image.unsqueeze(0)  # Add batch dimension
+    # Make a prediction with the model
     with torch.no_grad():
         output = model(image)
+        # Get the class with the highest probability
         _, predicted = torch.max(output.data, 1)
+    # Return the prediction as a dictionary
     return {"prediction": int(predicted[0])}
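For reference, a minimal client sketch for exercising the new /predict/ endpoint follows; it is not part of the diff. The uvicorn invocation, host/port, and the digit.png filename are assumptions, and the model expects a 28x28 grayscale digit image.

```python
# Hypothetical client for the /predict/ endpoint above.
# Assumes the app is served locally, e.g. `uvicorn api:app` run from src/,
# and that digit.png exists next to this script (both are assumptions).
import requests

with open("digit.png", "rb") as f:
    response = requests.post(
        "http://127.0.0.1:8000/predict/",
        files={"file": ("digit.png", f, "image/png")},
    )

response.raise_for_status()
print(response.json())  # expected shape: {"prediction": <digit 0-9>}
```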
diff --git a/src/main.py b/src/main.py
index 243a31e..712189e 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,28 +1,52 @@
-from PIL import Image
+"""
+This script defines the data loading and preprocessing steps, as well as the PyTorch model for MNIST digit classification.
+"""
+
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
-import numpy as np
+from torchvision import datasets, transforms
 
 # Step 1: Load MNIST Data and Preprocess
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# The transformation pipeline consists of two steps:
+# 1. transforms.ToTensor() - Converts the input image to PyTorch tensor.
+# 2. transforms.Normalize((0.5,), (0.5,)) - Normalizes the tensor with mean 0.5 and standard deviation 0.5.
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
-trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
+trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
 trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
 
 # Step 2: Define the PyTorch Model
+"""
+This class defines the PyTorch model for MNIST digit classification.
+The model consists of three fully connected layers.
+"""
+
+
 class Net(nn.Module):
     def __init__(self):
         super().__init__()
+        # First fully connected layer, takes input of size 28*28 and outputs size 128.
         self.fc1 = nn.Linear(28 * 28, 128)
+        # Second fully connected layer, takes input of size 128 and outputs size 64.
         self.fc2 = nn.Linear(128, 64)
+        # Third fully connected layer, takes input of size 64 and outputs size 10 (for 10 digit classes).
         self.fc3 = nn.Linear(64, 10)
-    
+
+    """
+    This method defines the forward pass of the model.
+    It applies the following transformations to the input:
+    1. Reshapes the input to a 1D tensor.
+    2. Applies the first fully connected layer followed by a ReLU activation function.
+    3. Applies the second fully connected layer followed by a ReLU activation function.
+    4. Applies the third fully connected layer.
+    5. Applies a log softmax function to the output of the third layer.
+    """
+
     def forward(self, x):
         x = x.view(-1, 28 * 28)
         x = nn.functional.relu(self.fc1(x))
@@ -30,6 +54,7 @@ def forward(self, x):
         x = self.fc3(x)
         return nn.functional.log_softmax(x, dim=1)
 
+
 # Step 3: Train the Model
 model = Net()
 optimizer = optim.SGD(model.parameters(), lr=0.01)
@@ -45,4 +70,4 @@ def forward(self, x):
         loss.backward()
         optimizer.step()
 
-torch.save(model.state_dict(), "mnist_model.pth")
\ No newline at end of file
+torch.save(model.state_dict(), "mnist_model.pth")
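For reference, a rough sketch of checking the saved checkpoint on the MNIST test split; it is not part of the diff. The batch size and data root are assumptions, and note that `from main import Net` executes main.py's top-level training loop at import time, since the script has no `if __name__ == "__main__":` guard (api.py imports it the same way).

```python
# Hypothetical evaluation sketch: reload the state dict that main.py saves
# and score it on the MNIST test split. Batch size and data root are assumptions.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from main import Net  # NB: this also re-runs main.py's training code

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
testset = datasets.MNIST(".", download=True, train=False, transform=transform)
testloader = DataLoader(testset, batch_size=256, shuffle=False)

model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

correct = total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Test accuracy: {correct / total:.4f}")
```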