diff --git a/src/api.py b/src/api.py
index 36c257a..4a1dc79 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,28 +1,55 @@
-from fastapi import FastAPI, UploadFile, File
-from PIL import Image
+"""
+This file loads the model, preprocesses the image, and makes predictions using the FastAPI framework.
+"""
+
+# PyTorch is an open source machine learning library based on the Torch library
 import torch
+
+# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints.
+from fastapi import FastAPI, File, UploadFile
+
+# PIL is used for opening, manipulating, and saving many different image file formats
+from PIL import Image
+
+# torchvision is a library for PyTorch that provides datasets and models for computer vision
 from torchvision import transforms
-from main import Net  # Importing Net class from main.py
 
-# Load the model
+# Importing Net class from main.py
+from main import Net
+
+# Load the model and set it to evaluation mode
 model = Net()
 model.load_state_dict(torch.load("mnist_model.pth"))
 model.eval()
 
-# Transform used for preprocessing the image
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# Transform used for preprocessing the image. It converts the image to tensor and normalizes it.
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
+# FastAPI application
 app = FastAPI()
 
+
 @app.post("/predict/")
 async def predict(file: UploadFile = File(...)):
+    """
+    This function takes an image file as input, preprocesses the image, and makes a prediction using the loaded model.
+    Parameters:
+    file (UploadFile): The image file to be processed and predicted.
+    Returns:
+    dict: A dictionary with the prediction result.
+    """
+    # Open the image file and convert it to grayscale
     image = Image.open(file.file).convert("L")
+    # Apply the transform to the image
     image = transform(image)
-    image = image.unsqueeze(0)  # Add batch dimension
+    # Add a batch dimension to the image
+    image = image.unsqueeze(0)
+    # Make a prediction without calculating gradients
     with torch.no_grad():
         output = model(image)
+        # Get the index of the max log-probability
         _, predicted = torch.max(output.data, 1)
+    # Return the prediction result
     return {"prediction": int(predicted[0])}
diff --git a/src/main.py b/src/main.py
index 243a31e..8285f4b 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,28 +1,53 @@
-from PIL import Image
+"""
+This file loads and preprocesses the MNIST dataset and defines a PyTorch model.
+"""
+
+# NumPy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices
+import numpy as np
+
+# PyTorch is an open source machine learning library based on the Torch library
 import torch
+
+# torch.nn is a sublibrary of PyTorch, provides classes to build neural networks
 import torch.nn as nn
+
+# torch.optim is a package implementing various optimization algorithms
 import torch.optim as optim
-from torchvision import datasets, transforms
+
+# DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset
 from torch.utils.data import DataLoader
-import numpy as np
+
+# torchvision is a library for PyTorch that provides datasets and models for computer vision
+from torchvision import datasets, transforms
+
+# PIL is used for opening, manipulating, and saving many different image file formats
+
 
 # Step 1: Load MNIST Data and Preprocess
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# transforms.Compose composes several transforms together, in this case, toTensor and Normalize
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
-trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
+# datasets.MNIST loads the MNIST dataset from the root directory with the specified transforms
+trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
+# DataLoader combines the dataset and a sampler, and provides an iterable over the given dataset
 trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
 
+
 # Step 2: Define the PyTorch Model
 class Net(nn.Module):
+    """
+    This class defines a simple feed-forward neural network for the MNIST dataset.
+    It has three fully connected layers and uses ReLU activation function for the first two layers and log softmax for the output layer.
+    """
+
     def __init__(self):
         super().__init__()
         self.fc1 = nn.Linear(28 * 28, 128)
         self.fc2 = nn.Linear(128, 64)
         self.fc3 = nn.Linear(64, 10)
-    
+
     def forward(self, x):
         x = x.view(-1, 28 * 28)
         x = nn.functional.relu(self.fc1(x))
@@ -30,6 +55,7 @@ def forward(self, x):
         x = self.fc3(x)
         return nn.functional.log_softmax(x, dim=1)
 
+
 # Step 3: Train the Model
 model = Net()
 optimizer = optim.SGD(model.parameters(), lr=0.01)
@@ -45,4 +71,4 @@ def forward(self, x):
         loss.backward()
         optimizer.step()
 
-torch.save(model.state_dict(), "mnist_model.pth")
\ No newline at end of file
+torch.save(model.state_dict(), "mnist_model.pth")