Skip to content

Commit ae992ad

Browse files
committed
modify train
1 parent e514f17 commit ae992ad

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

Diff for: train_ALL_CNN.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def train(train_iter, dev_iter, test_iter, model, args):
4040
best_accuracy = Best_Result()
4141
model.train()
4242
for epoch in range(1, args.epochs+1):
43-
print("\n## 第{} 轮迭代,共计迭代 {} 次 !##\n".format(epoch, args.epochs))
44-
print("now lr is {} \n".format(optimizer.param_groups[0].get("lr")))
43+
print("\n## The {} Epoch, All {} Epochs ! ##".format(epoch, args.epochs))
4544
for batch in train_iter:
4645
feature, target = batch.text, batch.label
4746
feature.data.t_(), target.data.sub_(1) # batch first, index align
@@ -72,9 +71,9 @@ def train(train_iter, dev_iter, test_iter, model, args):
7271
batch.batch_size))
7372
if steps % args.test_interval == 0:
7473
print("\nDev Accuracy: ", end="")
75-
eval(dev_iter, model, args, best_accuracy, test=False)
74+
eval(dev_iter, model, args, best_accuracy, epoch, test=False)
7675
print("Test Accuracy: ", end="")
77-
eval(test_iter, model, args, best_accuracy, test=True)
76+
eval(test_iter, model, args, best_accuracy, epoch, test=True)
7877
if steps % args.save_interval == 0:
7978
if not os.path.isdir(args.save_dir):
8079
os.makedirs(args.save_dir)
@@ -87,7 +86,7 @@ def train(train_iter, dev_iter, test_iter, model, args):
8786
return model_count
8887

8988

90-
def eval(data_iter, model, args, best_accuracy, test=False):
89+
def eval(data_iter, model, args, best_accuracy, epoch, test=False):
9190
model.eval()
9291
corrects, avg_loss = 0, 0
9392
for batch in data_iter:
@@ -109,13 +108,14 @@ def eval(data_iter, model, args, best_accuracy, test=False):
109108
if test is False:
110109
if accuracy >= best_accuracy.best_dev_accuracy:
111110
best_accuracy.best_dev_accuracy = accuracy
111+
best_accuracy.best_epoch = epoch
112112
best_accuracy.best_test = True
113113
if test is True and best_accuracy.best_test is True:
114114
best_accuracy.accuracy = accuracy
115115

116116
if test is True:
117-
print("The Current Best Dev Accuracy: {:.4f}, and Test Accuracy is :{:.4f}\n".format(best_accuracy.best_dev_accuracy,
118-
best_accuracy.accuracy))
117+
print("The Current Best Dev Accuracy: {:.4f}, and Test Accuracy is :{:.4f}, locate on {} epoch.\n".format(
118+
best_accuracy.best_dev_accuracy, best_accuracy.accuracy, best_accuracy.best_epoch))
119119
if test is True:
120120
best_accuracy.best_test = False
121121

Diff for: train_ALL_LSTM.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def train(train_iter, dev_iter, test_iter, model, args):
4040
best_accuracy = Best_Result()
4141
model.train()
4242
for epoch in range(1, args.epochs+1):
43-
print("\n## {} 轮迭代,共计迭代 {} 次 !##\n".format(epoch, args.epochs))
43+
print("\n## The {} Epoch, All {} Epochs ! ##".format(epoch, args.epochs))
4444
for batch in train_iter:
4545
feature, target = batch.text, batch.label.data.sub_(1)
4646
if args.cuda is True:
@@ -69,9 +69,9 @@ def train(train_iter, dev_iter, test_iter, model, args):
6969
batch.batch_size))
7070
if steps % args.test_interval == 0:
7171
print("\nDev Accuracy: ", end="")
72-
eval(dev_iter, model, args, best_accuracy, test=False)
72+
eval(dev_iter, model, args, best_accuracy, epoch, test=False)
7373
print("Test Accuracy: ", end="")
74-
eval(test_iter, model, args, best_accuracy, test=True)
74+
eval(test_iter, model, args, best_accuracy, epoch, test=True)
7575
if steps % args.save_interval == 0:
7676
if not os.path.isdir(args.save_dir): os.makedirs(args.save_dir)
7777
save_prefix = os.path.join(args.save_dir, 'snapshot')
@@ -83,7 +83,7 @@ def train(train_iter, dev_iter, test_iter, model, args):
8383
return model_count
8484

8585

86-
def eval(data_iter, model, args, best_accuracy, test=False):
86+
def eval(data_iter, model, args, best_accuracy, epoch, test=False):
8787
model.eval()
8888
corrects, avg_loss = 0, 0
8989
for batch in data_iter:
@@ -105,13 +105,14 @@ def eval(data_iter, model, args, best_accuracy, test=False):
105105
if test is False:
106106
if accuracy >= best_accuracy.best_dev_accuracy:
107107
best_accuracy.best_dev_accuracy = accuracy
108+
best_accuracy.best_epoch = epoch
108109
best_accuracy.best_test = True
109110
if test is True and best_accuracy.best_test is True:
110111
best_accuracy.accuracy = accuracy
111112

112113
if test is True:
113-
print("The Current Best Dev Accuracy: {:.4f}, and Test Accuracy is :{:.4f}\n".format(best_accuracy.best_dev_accuracy,
114-
best_accuracy.accuracy))
114+
print("The Current Best Dev Accuracy: {:.4f}, and Test Accuracy is :{:.4f}, locate on {} epoch.\n".format(
115+
best_accuracy.best_dev_accuracy, best_accuracy.accuracy, best_accuracy.best_epoch))
115116
if test is True:
116117
best_accuracy.best_test = False
117118

0 commit comments

Comments
 (0)