Skip to content

Commit 35f37cd

Browse files
feat: Updated src/main.py
1 parent 7284908 commit 35f37cd

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

src/main.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from PIL import Image
1+
import numpy as np
22
import torch
33
import torch.nn as nn
44
import torch.optim as optim
5-
from torchvision import datasets, transforms
5+
from PIL import Image
66
from torch.utils.data import DataLoader
7-
import numpy as np
7+
from torchvision import datasets, transforms
88

99
# Step 1: Load MNIST Data and Preprocess
1010
transform = transforms.Compose([
@@ -29,20 +29,29 @@ def forward(self, x):
2929
x = nn.functional.relu(self.fc2(x))
3030
x = self.fc3(x)
3131
return nn.functional.log_softmax(x, dim=1)
32+
class Trainer:
33+
def __init__(self, learning_rate, model_path):
34+
self.model = Net()
35+
self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate)
36+
self.criterion = nn.NLLLoss()
37+
self.model_path = model_path
3238

33-
# Step 3: Train the Model
34-
model = Net()
35-
optimizer = optim.SGD(model.parameters(), lr=0.01)
36-
criterion = nn.NLLLoss()
39+
def train(self, epochs):
40+
for epoch in range(epochs):
41+
for images, labels in trainloader:
42+
self.optimizer.zero_grad()
43+
output = self.model(images)
44+
loss = self.criterion(output, labels)
45+
loss.backward()
46+
self.optimizer.step()
47+
48+
def save_model(self):
49+
torch.save(self.model.state_dict(), self.model_path)
3750

38-
# Training loop
39-
epochs = 3
40-
for epoch in range(epochs):
41-
for images, labels in trainloader:
42-
optimizer.zero_grad()
43-
output = model(images)
44-
loss = criterion(output, labels)
45-
loss.backward()
46-
optimizer.step()
4751

48-
torch.save(model.state_dict(), "mnist_model.pth")
52+
# Step 3: Train the Model
53+
54+
# Now let's create a Trainer instance and train and save the model
55+
trainer = Trainer(learning_rate=0.01, model_path="mnist_model.pth")
56+
trainer.train(epochs=3)
57+
trainer.save_model()

0 commit comments

Comments
 (0)