diff --git a/src/api.py b/src/api.py
index 36c257a..0cc1a3d 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,28 +1,47 @@
-from fastapi import FastAPI, UploadFile, File
-from PIL import Image
+"""
+This script defines a FastAPI application that uses the PyTorch model defined
+in main.py to make predictions on uploaded images.
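+
+Example usage (a sketch, assuming the commands are run from the src/ directory
+and that main.py has already been run to create mnist_model.pth):
+
+    uvicorn api:app --reload
+    curl -X POST -F "file=@digit.png" http://127.0.0.1:8000/predict/
+
+where digit.png is any 28x28 grayscale image of a handwritten digit.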
+"""
+
 import torch
+from fastapi import FastAPI, File, UploadFile
+from PIL import Image
 from torchvision import transforms
+
 from main import Net  # Importing Net class from main.py
 
 # Load the model
+# The model weights are loaded from the saved state dictionary and the model is
+# set to evaluation mode with model.eval().
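+# Note: mnist_model.pth must exist before the app starts; it is produced by
+# running main.py, which trains the model and saves its weights.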
 model = Net()
 model.load_state_dict(torch.load("mnist_model.pth"))
 model.eval()
 
 # Transform used for preprocessing the image
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# The transformation pipeline consists of two steps:
+# 1. transforms.ToTensor() - Converts the input image to a PyTorch tensor
+#    with values in [0, 1].
+# 2. transforms.Normalize((0.5,), (0.5,)) - Normalizes the tensor with
+#    mean 0.5 and standard deviation 0.5.
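+# Note: this must match the preprocessing used during training in main.py, and
+# the model expects 28x28 grayscale input; images of other sizes are not
+# resized here.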
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
 app = FastAPI()
 
+
 @app.post("/predict/")
 async def predict(file: UploadFile = File(...)):
+    """
+    This function takes an uploaded file as input, preprocesses the image, and makes a prediction using the PyTorch model.
+    The prediction is returned as a dictionary with the key 'prediction'.
+    """
+    # Open the image file and convert it to grayscale
     image = Image.open(file.file).convert("L")
+    # Apply the transformation pipeline to the image
     image = transform(image)
+    # Add a batch dimension to the image tensor
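+    # (e.g. a 28x28 grayscale image goes from shape [1, 28, 28] to [1, 1, 28, 28])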
     image = image.unsqueeze(0)  # Add batch dimension
+    # Make a prediction with the model
     with torch.no_grad():
         output = model(image)
+        # Get the class with the highest probability
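+        # (the model outputs log-probabilities from log_softmax, so the largest
+        # value corresponds to the most likely digit)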
         _, predicted = torch.max(output.data, 1)
+    # Return the prediction as a dictionary
     return {"prediction": int(predicted[0])}
diff --git a/src/main.py b/src/main.py
index 243a31e..712189e 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,28 +1,52 @@
-from PIL import Image
+"""
+This script defines the data loading and preprocessing steps, the PyTorch model
+for MNIST digit classification, and the training loop. Running it downloads the
+MNIST data into the current directory, trains the model, and saves the trained
+weights to mnist_model.pth for use by api.py.
+"""
+
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
-import numpy as np
+from torchvision import datasets, transforms
 
 # Step 1: Load MNIST Data and Preprocess
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# The transformation pipeline consists of two steps:
+# 1. transforms.ToTensor() - Converts the input image to a PyTorch tensor
+#    with values in [0, 1].
+# 2. transforms.Normalize((0.5,), (0.5,)) - Normalizes the tensor with
+#    mean 0.5 and standard deviation 0.5.
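+#    For example, a pixel value of 1.0 becomes (1.0 - 0.5) / 0.5 = 1.0 and a
+#    value of 0.0 becomes -1.0, so normalized inputs lie in [-1, 1].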
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
-trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
+trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
 trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
 
 # Step 2: Define the PyTorch Model
+"""
+This class defines the PyTorch model for MNIST digit classification.
+The model consists of three fully connected layers.
+"""
+
+
 class Net(nn.Module):
+    """
+    PyTorch model for MNIST digit classification.
+
+    The model consists of three fully connected layers.
+    """
+
     def __init__(self):
         super().__init__()
+        # First fully connected layer, takes input of size 28*28 and outputs size 128.
         self.fc1 = nn.Linear(28 * 28, 128)
+        # Second fully connected layer, takes input of size 128 and outputs size 64.
         self.fc2 = nn.Linear(128, 64)
+        # Third fully connected layer, takes input of size 64 and outputs size 10 (for 10 digit classes).
         self.fc3 = nn.Linear(64, 10)
-    
+
+    """
+    This method defines the forward pass of the model.
+    It applies the following transformations to the input:
+    1. Reshapes the input to a 1D tensor.
+    2. Applies the first fully connected layer followed by a ReLU activation function.
+    3. Applies the second fully connected layer followed by a ReLU activation function.
+    4. Applies the third fully connected layer.
+    5. Applies a log softmax function to the output of the third layer.
+    """
+
     def forward(self, x):
+        """
+        Forward pass of the model.
+
+        The input is reshaped to a 1D tensor, passed through the first and
+        second fully connected layers with ReLU activations and the third
+        fully connected layer, and a log softmax is applied to the output.
+        """
         x = x.view(-1, 28 * 28)
         x = nn.functional.relu(self.fc1(x))
@@ -30,6 +54,7 @@ def forward(self, x):
         x = self.fc3(x)
         return nn.functional.log_softmax(x, dim=1)
 
+
 # Step 3: Train the Model
 model = Net()
 optimizer = optim.SGD(model.parameters(), lr=0.01)
@@ -45,4 +70,4 @@ def forward(self, x):
         loss.backward()
         optimizer.step()
 
-torch.save(model.state_dict(), "mnist_model.pth")
\ No newline at end of file
+torch.save(model.state_dict(), "mnist_model.pth")