Skip to content

Commit 7cfe902

Browse files
feat: Add CNN class with training method in cnn.py
1 parent 7284908 commit 7cfe902

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

src/cnn.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
from torch.utils.data import DataLoader
5+
from torchvision import datasets, transforms
6+
7+
8+
class CNN(nn.Module):
9+
def __init__(self):
10+
super(CNN, self).__init__()
11+
self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
12+
self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
13+
self.pool = nn.MaxPool2d(2, 2)
14+
self.fc1 = nn.Linear(64 * 4 * 4, 128)
15+
self.fc2 = nn.Linear(128, 10)
16+
17+
def forward(self, x):
18+
x = self.pool(nn.functional.relu(self.conv1(x)))
19+
x = self.pool(nn.functional.relu(self.conv2(x)))
20+
x = x.view(-1, 64 * 4 * 4)
21+
x = nn.functional.relu(self.fc1(x))
22+
x = self.fc2(x)
23+
return nn.functional.log_softmax(x, dim=1)
24+
25+
def train_cnn(self, trainloader, epochs=3):
26+
optimizer = optim.SGD(self.parameters(), lr=0.01)
27+
criterion = nn.NLLLoss()
28+
29+
for epoch in range(epochs):
30+
for images, labels in trainloader:
31+
optimizer.zero_grad()
32+
output = self(images)
33+
loss = criterion(output, labels)
34+
loss.backward()
35+
optimizer.step()
36+
37+
torch.save(self.state_dict(), "mnist_cnn_model.pth")
38+
39+
transform = transforms.Compose([
40+
transforms.ToTensor(),
41+
transforms.Normalize((0.5,), (0.5,))
42+
])
43+
44+
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
45+
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
46+
47+
cnn = CNN()
48+
cnn.train_cnn(trainloader)

0 commit comments

Comments
 (0)