From 6eb6a09b38963aaee67dcba98bb2a55196f36c69 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 04:15:13 +0000 Subject: [PATCH 1/2] feat: Updated src/main.py --- src/main.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/main.py b/src/main.py index 243a31e..13759b1 100644 --- a/src/main.py +++ b/src/main.py @@ -1,22 +1,42 @@ +""" +This module is used to load and preprocess the MNIST dataset and define a PyTorch model. +""" + +# PIL is used for opening, manipulating, and saving many different image file formats from PIL import Image +# torch is the main PyTorch library import torch +# torch.nn provides classes for building neural networks import torch.nn as nn +# torch.optim provides classes for implementing various optimization algorithms import torch.optim as optim +# torchvision.datasets provides classes for loading and manipulating datasets from torchvision import datasets, transforms +# torchvision.transforms provides classes for transforming images +# torch.utils.data provides utilities for loading and manipulating data from torch.utils.data import DataLoader +# numpy is used for numerical operations import numpy as np # Step 1: Load MNIST Data and Preprocess +# 'transform' is a sequence of transformations applied to the images in the dataset transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) +# 'trainset' represents the MNIST dataset trainset = datasets.MNIST('.', download=True, train=True, transform=transform) +# 'trainloader' is a data loader for the MNIST dataset trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model class Net(nn.Module): + """ + This class defines a simple feed-forward neural network for classifying MNIST images. + The network has three fully connected layers and uses ReLU activation functions. + The forward method takes an input tensor, reshapes it, and passes it through the network. + """ def __init__(self): super().__init__() self.fc1 = nn.Linear(28 * 28, 128) From 520f1b6e1b424afa21dc18c119d6d7000a141c2b Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 04:25:16 +0000 Subject: [PATCH 2/2] feat: Updated src/api.py --- src/api.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..c4a4b2d 100644 --- a/src/api.py +++ b/src/api.py @@ -1,24 +1,37 @@ +""" +This module creates a FastAPI application for making predictions using the PyTorch model defined in main.py. +""" +# FastAPI is used for creating the API from fastapi import FastAPI, UploadFile, File +# PIL is used for opening, manipulating, and saving many different image file formats from PIL import Image +# torch is the main PyTorch library import torch +# torchvision.transforms provides classes for transforming images from torchvision import transforms -from main import Net # Importing Net class from main.py +# Net is the PyTorch model defined in main.py +from main import Net -# Load the model +# 'model' represents the PyTorch model model = Net() model.load_state_dict(torch.load("mnist_model.pth")) model.eval() -# Transform used for preprocessing the image +# 'transform' is a sequence of transformations applied to the images transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) +# 'app' represents the FastAPI application app = FastAPI() @app.post("/predict/") async def predict(file: UploadFile = File(...)): + """ + This function takes an image file as input, preprocesses it, passes it through the model, and returns the model's prediction. + The input is an image file and the return value is a dictionary with the key 'prediction' and the model's prediction as the value. + """ image = Image.open(file.file).convert("L") image = transform(image) image = image.unsqueeze(0) # Add batch dimension