@@ -40,8 +40,7 @@ def train(train_iter, dev_iter, test_iter, model, args):
40
40
best_accuracy = Best_Result ()
41
41
model .train ()
42
42
for epoch in range (1 , args .epochs + 1 ):
43
- print ("\n ## 第{} 轮迭代,共计迭代 {} 次 !##\n " .format (epoch , args .epochs ))
44
- print ("now lr is {} \n " .format (optimizer .param_groups [0 ].get ("lr" )))
43
+ print ("\n ## The {} Epoch, All {} Epochs ! ##" .format (epoch , args .epochs ))
45
44
for batch in train_iter :
46
45
feature , target = batch .text , batch .label
47
46
feature .data .t_ (), target .data .sub_ (1 ) # batch first, index align
@@ -72,9 +71,9 @@ def train(train_iter, dev_iter, test_iter, model, args):
72
71
batch .batch_size ))
73
72
if steps % args .test_interval == 0 :
74
73
print ("\n Dev Accuracy: " , end = "" )
75
- eval (dev_iter , model , args , best_accuracy , test = False )
74
+ eval (dev_iter , model , args , best_accuracy , epoch , test = False )
76
75
print ("Test Accuracy: " , end = "" )
77
- eval (test_iter , model , args , best_accuracy , test = True )
76
+ eval (test_iter , model , args , best_accuracy , epoch , test = True )
78
77
if steps % args .save_interval == 0 :
79
78
if not os .path .isdir (args .save_dir ):
80
79
os .makedirs (args .save_dir )
@@ -87,7 +86,7 @@ def train(train_iter, dev_iter, test_iter, model, args):
87
86
return model_count
88
87
89
88
90
- def eval (data_iter , model , args , best_accuracy , test = False ):
89
+ def eval (data_iter , model , args , best_accuracy , epoch , test = False ):
91
90
model .eval ()
92
91
corrects , avg_loss = 0 , 0
93
92
for batch in data_iter :
@@ -109,13 +108,14 @@ def eval(data_iter, model, args, best_accuracy, test=False):
109
108
if test is False :
110
109
if accuracy >= best_accuracy .best_dev_accuracy :
111
110
best_accuracy .best_dev_accuracy = accuracy
111
+ best_accuracy .best_epoch = epoch
112
112
best_accuracy .best_test = True
113
113
if test is True and best_accuracy .best_test is True :
114
114
best_accuracy .accuracy = accuracy
115
115
116
116
if test is True :
117
- print ("The Current Best Dev Accuracy: {:.4f}, and Test Accuracy is :{:.4f}\n " .format (best_accuracy . best_dev_accuracy ,
118
- best_accuracy .accuracy ))
117
+ print ("The Current Best Dev Accuracy: {:.4f}, and Test Accuracy is :{:.4f}, locate on {} epoch. \n " .format (
118
+ best_accuracy . best_dev_accuracy , best_accuracy .accuracy , best_accuracy . best_epoch ))
119
119
if test is True :
120
120
best_accuracy .best_test = False
121
121
0 commit comments