File tree Expand file tree Collapse file tree 1 file changed +9
-13
lines changed Expand file tree Collapse file tree 1 file changed +9
-13
lines changed Original file line number Diff line number Diff line change 5
5
from torchvision import datasets , transforms
6
6
from torch .utils .data import DataLoader
7
7
import numpy as np
8
+ from cnn import CNN # Import the CNN class
8
9
9
10
# Step 1: Load MNIST Data and Preprocess
10
11
transform = transforms .Compose ([
16
17
trainloader = DataLoader (trainset , batch_size = 64 , shuffle = True )
17
18
18
19
# Step 2: Define the PyTorch Model
19
- class Net (nn .Module ):
20
- def __init__ (self ):
21
- super ().__init__ ()
22
- self .fc1 = nn .Linear (28 * 28 , 128 )
23
- self .fc2 = nn .Linear (128 , 64 )
24
- self .fc3 = nn .Linear (64 , 10 )
25
-
26
- def forward (self , x ):
27
- x = x .view (- 1 , 28 * 28 )
28
- x = nn .functional .relu (self .fc1 (x ))
29
- x = nn .functional .relu (self .fc2 (x ))
30
- x = self .fc3 (x )
31
- return nn .functional .log_softmax (x , dim = 1 )
20
+ # Create an instance of the CNN class
21
+ cnn = CNN ()
22
+
23
+ # Train the CNN
24
+ cnn .train (trainloader , lr = 0.001 , epochs = 10 )
25
+
26
+ # Save the trained model
27
+ cnn .save_model ("mnist_model.pth" )
32
28
33
29
# Step 3: Train the Model
34
30
model = Net ()
You can’t perform that action at this time.
0 commit comments