From cd02ab29466b7bc1efa36e7e33ec974160ce6250 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 09:03:44 +0000 Subject: [PATCH 1/3] feat: Updated src/main.py --- src/main.py | 50 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..33c661a 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,8 @@ +""" +This module is used to train a simple neural network model on the MNIST dataset using PyTorch. +It includes steps for loading and preprocessing the dataset, defining the model, and training the model. +""" + from PIL import Image import torch import torch.nn as nn @@ -7,28 +12,53 @@ import numpy as np # Step 1: Load MNIST Data and Preprocess +# Define the transformations to be applied on the images transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) + transforms.ToTensor(), # Convert the images to tensors + transforms.Normalize((0.5,), (0.5,)) # Normalize the images ]) +# Load the MNIST dataset and apply the transformations trainset = datasets.MNIST('.', download=True, train=True, transform=transform) + +# Create a DataLoader for the dataset trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model class Net(nn.Module): + """ + This class defines a simple feed-forward neural network model. + + Methods + ------- + forward(x) + Defines the forward pass of the model. + """ def __init__(self): super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) + self.fc1 = nn.Linear(28 * 28, 128) # First fully connected layer + self.fc2 = nn.Linear(128, 64) # Second fully connected layer + self.fc3 = nn.Linear(64, 10) # Third fully connected layer def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) + """ + Defines the forward pass of the model. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + The output tensor. + """ + x = x.view(-1, 28 * 28) # Flatten the input tensor + x = nn.functional.relu(self.fc1(x)) # Apply ReLU activation function after the first layer + x = nn.functional.relu(self.fc2(x)) # Apply ReLU activation function after the second layer + x = self.fc3(x) # Apply the third layer + return nn.functional.log_softmax(x, dim=1) # Apply log softmax to the output # Step 3: Train the Model model = Net() From 0f6c426bebf9acebdd34eb9c5a9758f9a5957330 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 09:04:22 +0000 Subject: [PATCH 2/3] feat: Updated requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9679557..2f2271f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ certifi==2022.12.7 charset-normalizer==2.1.1 click==8.1.7 dill==0.3. -distutils exceptiongroup==1.1.3 fastapi==0.104.0 filelock==3.9.0 From ad85e5be21a5e7c2fa523515ab73d27b431925d1 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 09:06:27 +0000 Subject: [PATCH 3/3] feat: Updated src/api.py --- src/api.py | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..e6562e2 100644 --- a/src/api.py +++ b/src/api.py @@ -1,28 +1,46 @@ +""" +This module defines a FastAPI application that serves a model prediction endpoint. +The endpoint accepts an image file, preprocesses the image, and returns a prediction from a pre-trained model. +""" + from fastapi import FastAPI, UploadFile, File from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py +from main import Net # Import the Net class from main.py -# Load the model -model = Net() -model.load_state_dict(torch.load("mnist_model.pth")) -model.eval() +# Load the pre-trained model +model = Net() # Initialize the model +model.load_state_dict(torch.load("mnist_model.pth")) # Load the pre-trained weights +model.eval() # Set the model to evaluation mode -# Transform used for preprocessing the image +# Define the transform used for preprocessing the image transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) + transforms.ToTensor(), # Convert the image to a PyTorch tensor + transforms.Normalize((0.5,), (0.5,)) # Normalize the tensor ]) app = FastAPI() @app.post("/predict/") async def predict(file: UploadFile = File(...)): - image = Image.open(file.file).convert("L") - image = transform(image) - image = image.unsqueeze(0) # Add batch dimension - with torch.no_grad(): - output = model(image) - _, predicted = torch.max(output.data, 1) - return {"prediction": int(predicted[0])} + """ + Predict the digit in an uploaded image. + + Parameters + ---------- + file : UploadFile + The image file to predict. + + Returns + ------- + dict + A dictionary with a single key "prediction" and the predicted digit as the value. + """ + image = Image.open(file.file).convert("L") # Open the image file and convert it to grayscale + image = transform(image) # Apply the preprocessing transform + image = image.unsqueeze(0) # Add a batch dimension to the tensor + with torch.no_grad(): # Disable gradient computation + output = model(image) # Forward pass through the model + _, predicted = torch.max(output.data, 1) # Get the index of the max log-probability + return {"prediction": int(predicted[0])} # Return the prediction as a JSON response