diff --git a/src/api.py b/src/api.py index 36c257a..86e6782 100644 --- a/src/api.py +++ b/src/api.py @@ -1,24 +1,26 @@ -from fastapi import FastAPI, UploadFile, File -from PIL import Image import torch +from fastapi import FastAPI, File, UploadFile +from PIL import Image from torchvision import transforms + from main import Net # Importing Net class from main.py -# Load the model -model = Net() -model.load_state_dict(torch.load("mnist_model.pth")) -model.eval() +# Instantiate the trainer and load the model +trainer = MNISTTrainer() +model = trainer.load_model("mnist_model.pth") # Transform used for preprocessing the image -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) app = FastAPI() + @app.post("/predict/") -async def predict(file: UploadFile = File(...)): +async def predict(file: UploadFile = None): + if file is None: + file = File(...) image = Image.open(file.file).convert("L") image = transform(image) image = image.unsqueeze(0) # Add batch dimension diff --git a/src/main.py b/src/main.py index 243a31e..e7d4f35 100644 --- a/src/main.py +++ b/src/main.py @@ -1,48 +1,76 @@ -from PIL import Image import torch import torch.nn as nn import torch.optim as optim -from torchvision import datasets, transforms from torch.utils.data import DataLoader -import numpy as np - -# Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +from torchvision import datasets, transforms -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) -trainloader = DataLoader(trainset, batch_size=64, shuffle=True) -# Step 2: Define the PyTorch Model -class Net(nn.Module): +class MNISTTrainer: def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) - -# Step 3: Train the Model -model = Net() -optimizer = optim.SGD(model.parameters(), lr=0.01) -criterion = nn.NLLLoss() - -# Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() - -torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file + self.transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] + ) + self.trainset = None + self.trainloader = None + self.model = None + self.optimizer = None + self.criterion = nn.NLLLoss() + self.epochs = 3 + + def load_data(self): + self.trainset = datasets.MNIST( + ".", download=True, train=True, transform=self.transform + ) + self.trainloader = DataLoader(self.trainset, batch_size=64, shuffle=True) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(28 * 28, 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = nn.functional.relu(self.fc1(x)) + x = nn.functional.relu(self.fc2(x)) + x = self.fc3(x) + return nn.functional.log_softmax(x, dim=1) + + def define_model(self): + self.model = self.Net() + self.optimizer = optim.SGD(self.model.parameters(), lr=0.01) + + def train(self): + for epoch in range(self.epochs): + for images, labels in self.trainloader: + self.optimizer.zero_grad() + output = self.model(images) + loss = self.criterion(output, labels) + loss.backward() + self.optimizer.step() + + def save_model(self, path="mnist_model.pth"): + torch.save(self.model.state_dict(), path) + + def evaluate_model(self, validation_loader): + correct = 0 + total = 0 + with torch.no_grad(): + for data in validation_loader: + images, labels = data + outputs = self.model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + print( + "Accuracy of the network on the validation images: %d %%" + % (100 * correct / total) + ) + + +trainer = MNISTTrainer() +trainer.load_data() +trainer.define_model() +trainer.train() +trainer.save_model()