2
2
import torch
3
3
import torch .nn as nn
4
4
import torch .optim as optim
5
- from cnn import CNN , train
6
5
from PIL import Image
7
6
from torch .utils .data import DataLoader
8
7
from torchvision import datasets , transforms
9
8
9
+ from cnn import CNN , train
10
+
10
11
# Step 1: Load MNIST Data and Preprocess
11
- transform = transforms .Compose ([
12
- transforms .ToTensor (),
13
- transforms .Normalize ((0.5 ,), (0.5 ,))
14
- ])
12
+ transform = transforms .Compose (
13
+ [transforms .ToTensor (), transforms .Normalize ((0.5 ,), (0.5 ,))]
14
+ )
15
15
16
- trainset = datasets .MNIST ('.' , download = True , train = True , transform = transform )
16
+ trainset = datasets .MNIST ("." , download = True , train = True , transform = transform )
17
17
trainloader = DataLoader (trainset , batch_size = 64 , shuffle = True )
18
18
19
19
20
-
21
-
22
-
23
-
24
-
25
-
26
-
27
-
28
-
29
-
30
-
31
-
32
-
33
-
34
-
35
20
model = Net ()
36
21
model = CNN ()
37
22
optimizer = optim .SGD (model .parameters (), lr = 0.01 )
41
26
train (model , trainloader , optimizer )
42
27
43
28
44
- torch .save (model .state_dict (), "mnist_model.pth" )
29
+ torch .save (model .state_dict (), "mnist_model.pth" )
0 commit comments