|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.optim as optim |
| 4 | +from torch.utils.data import DataLoader |
| 5 | +from torchvision import datasets, transforms |
| 6 | + |
| 7 | + |
| 8 | +class CNN(nn.Module): |
| 9 | + def __init__(self): |
| 10 | + super(CNN, self).__init__() |
| 11 | + self.conv1 = nn.Conv2d(1, 32, kernel_size=5) |
| 12 | + self.conv2 = nn.Conv2d(32, 64, kernel_size=5) |
| 13 | + self.pool = nn.MaxPool2d(2, 2) |
| 14 | + self.fc1 = nn.Linear(64 * 4 * 4, 128) |
| 15 | + self.fc2 = nn.Linear(128, 10) |
| 16 | + |
| 17 | + def forward(self, x): |
| 18 | + x = self.pool(nn.functional.relu(self.conv1(x))) |
| 19 | + x = self.pool(nn.functional.relu(self.conv2(x))) |
| 20 | + x = x.view(-1, 64 * 4 * 4) |
| 21 | + x = nn.functional.relu(self.fc1(x)) |
| 22 | + x = self.fc2(x) |
| 23 | + return nn.functional.log_softmax(x, dim=1) |
| 24 | + |
| 25 | + def train_cnn(self, trainloader, epochs=3): |
| 26 | + optimizer = optim.SGD(self.parameters(), lr=0.01) |
| 27 | + criterion = nn.NLLLoss() |
| 28 | + |
| 29 | + for epoch in range(epochs): |
| 30 | + for images, labels in trainloader: |
| 31 | + optimizer.zero_grad() |
| 32 | + output = self(images) |
| 33 | + loss = criterion(output, labels) |
| 34 | + loss.backward() |
| 35 | + optimizer.step() |
| 36 | + |
| 37 | + torch.save(self.state_dict(), "mnist_cnn_model.pth") |
| 38 | + |
| 39 | +transform = transforms.Compose([ |
| 40 | + transforms.ToTensor(), |
| 41 | + transforms.Normalize((0.5,), (0.5,)) |
| 42 | +]) |
| 43 | + |
| 44 | +trainset = datasets.MNIST('.', download=True, train=True, transform=transform) |
| 45 | +trainloader = DataLoader(trainset, batch_size=64, shuffle=True) |
| 46 | + |
| 47 | +cnn = CNN() |
| 48 | +cnn.train_cnn(trainloader) |
0 commit comments