Skip to content

Commit

Permalink
feat: add CNN class for MNIST classification
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 13, 2023
1 parent a7d4e3a commit 3a75d53
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

class CNN(nn.Module):
"""
Convolutional Neural Network for MNIST classification.
"""
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(7*7*64, 128)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = x.view(-1, 7*7*64)
x = self.relu3(self.fc1(x))
return self.fc2(x)

def train(self, trainloader, lr, epochs):
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(self.parameters(), lr=lr)

for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = self(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}")

def save_model(self, file_path):
torch.save(self.state_dict(), file_path)

0 comments on commit 3a75d53

Please sign in to comment.