1
- from PIL import Image
1
+ import numpy as np
2
2
import torch
3
3
import torch .nn as nn
4
4
import torch .optim as optim
5
- from torchvision import datasets , transforms
5
+ from PIL import Image
6
6
from torch .utils .data import DataLoader
7
- import numpy as np
7
+ from torchvision import datasets , transforms
8
8
9
9
# Step 1: Load MNIST Data and Preprocess
10
10
transform = transforms .Compose ([
@@ -29,20 +29,29 @@ def forward(self, x):
29
29
x = nn .functional .relu (self .fc2 (x ))
30
30
x = self .fc3 (x )
31
31
return nn .functional .log_softmax (x , dim = 1 )
32
+ class Trainer :
33
+ def __init__ (self , learning_rate , model_path ):
34
+ self .model = Net ()
35
+ self .optimizer = optim .SGD (self .model .parameters (), lr = learning_rate )
36
+ self .criterion = nn .NLLLoss ()
37
+ self .model_path = model_path
32
38
33
- # Step 3: Train the Model
34
- model = Net ()
35
- optimizer = optim .SGD (model .parameters (), lr = 0.01 )
36
- criterion = nn .NLLLoss ()
39
+ def train (self , epochs ):
40
+ for epoch in range (epochs ):
41
+ for images , labels in trainloader :
42
+ self .optimizer .zero_grad ()
43
+ output = self .model (images )
44
+ loss = self .criterion (output , labels )
45
+ loss .backward ()
46
+ self .optimizer .step ()
47
+
48
+ def save_model (self ):
49
+ torch .save (self .model .state_dict (), self .model_path )
37
50
38
- # Training loop
39
- epochs = 3
40
- for epoch in range (epochs ):
41
- for images , labels in trainloader :
42
- optimizer .zero_grad ()
43
- output = model (images )
44
- loss = criterion (output , labels )
45
- loss .backward ()
46
- optimizer .step ()
47
51
48
- torch .save (model .state_dict (), "mnist_model.pth" )
52
+ # Step 3: Train the Model
53
+
54
+ # Now let's create a Trainer instance and train and save the model
55
+ trainer = Trainer (learning_rate = 0.01 , model_path = "mnist_model.pth" )
56
+ trainer .train (epochs = 3 )
57
+ trainer .save_model ()
0 commit comments