We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 7284908 commit 655aa57Copy full SHA for 655aa57
src/cnn.py
@@ -0,0 +1,21 @@
1
+import torch
2
+import torch.nn as nn
3
+import torch.nn.functional as F
4
+
5
6
+class CNN(nn.Module):
7
+ def __init__(self):
8
+ super(CNN, self).__init__()
9
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
10
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
11
+ self.pool = nn.MaxPool2d(2, 2)
12
+ self.fc1 = nn.Linear(64 * 4 * 4, 128)
13
+ self.fc2 = nn.Linear(128, 10)
14
15
+ def forward(self, x):
16
+ x = self.pool(F.relu(self.conv1(x)))
17
+ x = self.pool(F.relu(self.conv2(x)))
18
+ x = x.view(-1, 64 * 4 * 4)
19
+ x = F.relu(self.fc1(x))
20
+ x = self.fc2(x)
21
+ return F.log_softmax(x, dim=1)
0 commit comments