From 75a9c677a75d618ad8519a60b096da327f02424f Mon Sep 17 00:00:00 2001
From: "sweep-nightly[bot]"
 <131841235+sweep-nightly[bot]@users.noreply.github.com>
Date: Fri, 13 Oct 2023 23:09:27 +0000
Subject: [PATCH 1/2] feat: Updated src/main.py

---
 src/main.py | 41 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/src/main.py b/src/main.py
index 243a31e..b4293bd 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,29 +1,70 @@
+"""
+This module is used to train a simple neural network model on the MNIST dataset using PyTorch.
+"""
+
+# PIL is used for opening, manipulating, and saving many different image file formats
 from PIL import Image
+# PyTorch is an open source machine learning library
 import torch
+# torch.nn is a sublibrary of PyTorch, used for building and training neural networks
 import torch.nn as nn
+# torch.optim is a package implementing various optimization algorithms
 import torch.optim as optim
+# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
 from torchvision import datasets, transforms
+# DataLoader is a PyTorch function for loading and managing datasets
 from torch.utils.data import DataLoader
+# numpy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays
 import numpy as np
 
 # Step 1: Load MNIST Data and Preprocess
+# The transforms are converting the data to PyTorch tensors and normalizing the pixel values
 transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
 ])
 
+# The trainset is the MNIST dataset that we will train our model on
+# The trainloader is a DataLoader which provides functionality for batching, shuffling and loading data in parallel
 trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
 trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
 
 # Step 2: Define the PyTorch Model
+# The Net class is a definition of a simple neural network
 class Net(nn.Module):
+    """
+    A simple neural network with one hidden layer.
+
+    ...
+
+    Methods
+    -------
+    forward(x):
+        Defines the computation performed at every call.
+    """
     def __init__(self):
+        """
+        Initializes the neural network with one hidden layer.
+        """
         super().__init__()
         self.fc1 = nn.Linear(28 * 28, 128)
         self.fc2 = nn.Linear(128, 64)
         self.fc3 = nn.Linear(64, 10)
     
     def forward(self, x):
+        """
+        Defines the computation performed at every call.
+
+        Parameters
+        ----------
+        x : Tensor
+            The input data.
+
+        Returns
+        -------
+        Tensor
+            The output of the neural network.
+        """
         x = x.view(-1, 28 * 28)
         x = nn.functional.relu(self.fc1(x))
         x = nn.functional.relu(self.fc2(x))

From 831f6f074270642dd87e48474dfb8559353d551f Mon Sep 17 00:00:00 2001
From: "sweep-nightly[bot]"
 <131841235+sweep-nightly[bot]@users.noreply.github.com>
Date: Fri, 13 Oct 2023 23:10:55 +0000
Subject: [PATCH 2/2] feat: Updated src/api.py

---
 src/api.py | 30 +++++++++++++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/src/api.py b/src/api.py
index 36c257a..19ceab0 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,24 +1,48 @@
+"""
+This module is used to create a FastAPI application that serves a machine learning model trained on the MNIST dataset.
+"""
+
+# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints.
 from fastapi import FastAPI, UploadFile, File
+# PIL is used for opening, manipulating, and saving many different image file formats
 from PIL import Image
+# PyTorch is an open source machine learning library
 import torch
+# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
 from torchvision import transforms
-from main import Net  # Importing Net class from main.py
+# Importing Net class from main.py
+from main import Net  
 
-# Load the model
+# Load the model. This is the trained neural network model that we will use for making predictions.
 model = Net()
 model.load_state_dict(torch.load("mnist_model.pth"))
 model.eval()
 
-# Transform used for preprocessing the image
+# Transform used for preprocessing the image. The transforms are converting the data to PyTorch tensors and normalizing the pixel values.
 transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
 ])
 
+# FastAPI application. This is the main entry point for our API.
 app = FastAPI()
 
+# This function is used to make predictions on uploaded images.
 @app.post("/predict/")
 async def predict(file: UploadFile = File(...)):
+    """
+    Make a prediction on an uploaded image.
+
+    Parameters
+    ----------
+    file : UploadFile
+        The image file to make a prediction on.
+
+    Returns
+    -------
+    dict
+        A dictionary with a single key 'prediction' and the predicted digit as the value.
+    """
     image = Image.open(file.file).convert("L")
     image = transform(image)
     image = image.unsqueeze(0)  # Add batch dimension