 from model import TextBackbone
 from data import ReviewsDataset, load_data

+from transformers import AdamW
+
 def parse_arguments():
     parser = argparse.ArgumentParser(
         description='Simple Sentiment Analysis with PyTorch and Transformers'
@@ -15,7 +17,7 @@ def parse_arguments():
     parser.add_argument('--n_classes', default=2, type=int, help='number of classes')

     parser.add_argument('--data_path', type=str, default='data/data.txt', help='the path of the dataset')
-    parser.add_argument('--batch_size', default=20, type=int, help='batch size')
+    parser.add_argument('--batch_size', default=8, type=int, help='batch size')

     parser.add_argument('--epochs', default=50, type=int, help='number of epochs to train for')
     parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
@@ -28,9 +30,12 @@ def parse_arguments():
 def train(model, dataset, optimizer, device, batch_size, epochs):
     model.train()
     train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
-    pbar = tqdm(train_loader)
+
+    min_loss = float('inf')
     for epoch in range(epochs):
+        pbar = tqdm(train_loader)
         pbar.set_description("Epoch {}:".format(epoch))
+        total_loss = 0
         for batch in pbar:
             batch = {key: value.to(device) for key, value in batch.items()}
             optimizer.zero_grad()
@@ -39,9 +44,13 @@ def train(model, dataset, optimizer, device, batch_size, epochs):
             loss.backward()
             optimizer.step()
             pbar.set_postfix(loss=loss.item())
+            total_loss += loss.item()

-        if epoch % 10 == 0:
-            model.save('model_{}.pth'.format(epoch))
+        if total_loss < min_loss:
+            min_loss = total_loss
+            torch.save(model.state_dict(), 'output/model_best.pth')
+
+        print("Epoch {}: Average loss: {}".format(epoch, total_loss / len(train_loader)))

     return model

@@ -51,7 +60,8 @@ def main():
     reviews, targets = load_data(args.data_path)
     dataset = ReviewsDataset(reviews, targets)
     model = TextBackbone(num_classes=args.n_classes).to(args.device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
+
     model = train(model, dataset, optimizer, args.device, args.batch_size, args.epochs)
     torch.save(model.state_dict(), 'model.pth')

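For reference, a minimal sketch of how the best checkpoint written by the updated train() loop could be reloaded for evaluation. TextBackbone and the output/model_best.pth path come from the diff above; the class count, CPU map location, and eval-only usage are assumptions for illustration, not part of this commit.

# Sketch (not part of this commit): reload the best checkpoint saved by
# train() and put the model in evaluation mode before running inference.
import torch

from model import TextBackbone

model = TextBackbone(num_classes=2)  # 2 matches the --n_classes default above
state_dict = torch.load('output/model_best.pth', map_location='cpu')  # assumes CPU inference
model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference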