From c440087c8265dbd5a69b7c8e77ba71913762f101 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 19:40:51 +0000 Subject: [PATCH 1/3] feat: Updated src/main.py --- src/main.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..7702210 100644 --- a/src/main.py +++ b/src/main.py @@ -1,29 +1,54 @@ -from PIL import Image +""" +This script loads and preprocesses the MNIST dataset, defines a PyTorch model, and trains the model. +""" +# PIL is used for opening, manipulating, and saving many different image file formats +import PIL.Image +# torch is the main PyTorch library import torch +# torch.nn provides classes for building neural networks import torch.nn as nn +# torch.optim provides classes for implementing various optimization algorithms import torch.optim as optim -from torchvision import datasets, transforms +# torchvision.datasets provides classes for loading and using various popular datasets +from torchvision import datasets +# torchvision.transforms provides classes for transforming images +from torchvision import transforms +# torch.utils.data provides classes for loading data in parallel from torch.utils.data import DataLoader +# numpy is used for numerical operations import numpy as np # Step 1: Load MNIST Data and Preprocess +# This is a sequence of preprocessing steps to be applied to the images in the MNIST dataset transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) +# This represents the MNIST training dataset trainset = datasets.MNIST('.', download=True, train=True, transform=transform) +# This is a data loader for batching and shuffling the training data trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model +# This class defines the architecture of the PyTorch model class Net(nn.Module): + """ + This class defines the architecture of the PyTorch model. + """ def __init__(self): + """ + This method initializes the model. + """ super().__init__() self.fc1 = nn.Linear(28 * 28, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) def forward(self, x): + """ + This method defines the forward pass of the model. + """ x = x.view(-1, 28 * 28) x = nn.functional.relu(self.fc1(x)) x = nn.functional.relu(self.fc2(x)) From 913547b195c26a31f07905f5d9e6a7edd67503d4 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 19:50:25 +0000 Subject: [PATCH 2/3] feat: Updated src/api.py --- src/api.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..43e2ee4 100644 --- a/src/api.py +++ b/src/api.py @@ -1,24 +1,44 @@ +""" +This script creates a FastAPI application for making predictions using the PyTorch model defined in main.py. +""" +# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints. from fastapi import FastAPI, UploadFile, File +# PIL is used for opening, manipulating, and saving many different image file formats from PIL import Image +# torch is the main PyTorch library import torch +# torchvision.transforms provides classes for transforming images from torchvision import transforms -from main import Net # Importing Net class from main.py +# Importing Net class from main.py +from main import Net -# Load the model +# Load the trained PyTorch model from a file model = Net() model.load_state_dict(torch.load("mnist_model.pth")) model.eval() -# Transform used for preprocessing the image +# Define a sequence of preprocessing steps to be applied to the input images transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) +# Create an instance of the FastAPI application app = FastAPI() +# Define a route handler for making predictions using the model @app.post("/predict/") async def predict(file: UploadFile = File(...)): + """ + This function is a route handler for making predictions using the model. + It takes an image file as input and returns a prediction. + + Parameters: + file (UploadFile): The image file to predict. + + Returns: + dict: A dictionary with the prediction. + """ image = Image.open(file.file).convert("L") image = transform(image) image = image.unsqueeze(0) # Add batch dimension From c758ee761529c3ccd87263349d5f85a81c5695b0 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 19:53:02 +0000 Subject: [PATCH 3/3] Sandbox run src/api.py --- src/api.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/api.py b/src/api.py index 43e2ee4..475b5c0 100644 --- a/src/api.py +++ b/src/api.py @@ -1,14 +1,18 @@ """ This script creates a FastAPI application for making predictions using the PyTorch model defined in main.py. """ +# torch is the main PyTorch library +import torch + # FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints. -from fastapi import FastAPI, UploadFile, File +from fastapi import FastAPI, File, UploadFile + # PIL is used for opening, manipulating, and saving many different image file formats from PIL import Image -# torch is the main PyTorch library -import torch + # torchvision.transforms provides classes for transforming images from torchvision import transforms + # Importing Net class from main.py from main import Net @@ -18,14 +22,14 @@ model.eval() # Define a sequence of preprocessing steps to be applied to the input images -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) # Create an instance of the FastAPI application app = FastAPI() + # Define a route handler for making predictions using the model @app.post("/predict/") async def predict(file: UploadFile = File(...)):