train_chatbot.py
import time
import math
import sys
import pickle
import glob
import os
import tensorflow as tf
from seq2seq_model import Seq2SeqModel
from corpora_tools import *
from corpora_get import *
path_l1_dict = "tmp/l1_dict.p"
path_l2_dict = "tmp/l2_dict.p"
model_dir = "tmp/chat"
model_checkpoints = model_dir + "/chat.ckpt"
def build_dataset(use_stored_dictionary=False):
    # Load the Ubuntu dialogue corpus and clean both sides of the conversation.
    sen_l1, sen_l2 = get_ubuntu_corpus_data()
    clean_sen_l1 = [clean_sentence(s) for s in sen_l1]  ### OTHERWISE IT DOES NOT RUN ON MY LAPTOP
    clean_sen_l2 = [clean_sentence(s) for s in sen_l2]  ### OTHERWISE IT DOES NOT RUN ON MY LAPTOP
    filt_clean_sen_l1, filt_clean_sen_l2 = filter_sentence_length(clean_sen_l1, clean_sen_l2, max_len=20)

    if not use_stored_dictionary:
        # Change dict_size according to the input size.
        dict_l1 = create_indexed_dictionary(filt_clean_sen_l1, dict_size=20000, storage_path=path_l1_dict)
        dict_l2 = create_indexed_dictionary(filt_clean_sen_l2, dict_size=20000, storage_path=path_l2_dict)
    else:
        # Reuse the dictionaries pickled by a previous run.
        dict_l1 = pickle.load(open(path_l1_dict, "rb"))
        dict_l2 = pickle.load(open(path_l2_dict, "rb"))
    dict_l1_length = len(dict_l1)
    dict_l2_length = len(dict_l2)

    # Convert token sequences to index sequences and prepare the (encoder, decoder) pairs.
    idx_sentences_l1 = sentences_to_indexes(filt_clean_sen_l1, dict_l1)
    idx_sentences_l2 = sentences_to_indexes(filt_clean_sen_l2, dict_l2)
    max_length_l1 = extract_max_length(idx_sentences_l1)
    max_length_l2 = extract_max_length(idx_sentences_l2)
    data_set = prepare_sentences(idx_sentences_l1, idx_sentences_l2, max_length_l1, max_length_l2)

    return (filt_clean_sen_l1, filt_clean_sen_l2), data_set, (max_length_l1, max_length_l2), (dict_l1_length, dict_l2_length)
def cleanup_checkpoints(model_dir, model_checkpoints):
    # Remove any previous checkpoints so training starts from scratch,
    # and make sure the model directory exists.
    for f in glob.glob(model_checkpoints + "*"):
        os.remove(f)
    try:
        os.mkdir(model_dir)
    except FileExistsError:
        pass
def get_seq2seq_model(session, forward_only, dict_lengths, max_sentence_lengths, model_dir):
    model = Seq2SeqModel(
        source_vocab_size=dict_lengths[0],
        target_vocab_size=dict_lengths[1],
        buckets=[max_sentence_lengths],
        size=256,
        num_layers=2,
        max_gradient_norm=5.0,
        batch_size=128,
        learning_rate=1.0,
        learning_rate_decay_factor=0.99,
        forward_only=forward_only)
        # dtype=tf.float16)

    # Restore the latest checkpoint if one exists, otherwise start from fresh parameters.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from {}".format(ckpt.model_checkpoint_path))
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    return model
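
# A minimal sketch of how this builder could be reused for inference (an
# assumption about typical usage, not part of the original script): passing
# forward_only=True builds the model for decoding only, and the checkpoint
# restore above picks up whatever train() has saved, e.g.:
#
#   with tf.Session() as sess:
#       chat_model = get_seq2seq_model(sess, True, dict_lengths,
#                                      max_sentence_lengths, model_dir)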
def train():
    with tf.Session() as sess:
        # data_set, max_sentence_lengths and dict_lengths are module-level names
        # assigned in the __main__ block below.
        model = get_seq2seq_model(sess, False, dict_lengths, max_sentence_lengths, model_dir)

        # This is the training loop.
        step_time, loss = 0.0, 0.0
        current_step = 0
        bucket = 0
        steps_per_checkpoint = 100
        max_steps = 10000  # change to a larger number later

        while current_step < max_steps:
            start_time = time.time()
            encoder_inputs, decoder_inputs, target_weights = model.get_batch([data_set], bucket)
            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket, False)
            step_time += (time.time() - start_time) / steps_per_checkpoint
            loss += step_loss / steps_per_checkpoint
            current_step += 1

            if current_step % steps_per_checkpoint == 0:
                # Report training perplexity, decay the learning rate and save a checkpoint.
                perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
                print("global step {} learning rate {} step_time {} perplexity {}".format(
                    model.global_step.eval(), model.learning_rate.eval(), step_time, perplexity))
                sess.run(model.learning_rate_decay_op)
                model.saver.save(sess, model_checkpoints, global_step=model.global_step)
                step_time, loss = 0.0, 0.0

                # Evaluate on a freshly sampled batch (forward pass only).
                encoder_inputs, decoder_inputs, target_weights = model.get_batch([data_set], bucket)
                _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket, True)
                eval_ppx = math.exp(float(eval_loss)) if eval_loss < 300 else float("inf")
                print("  eval: perplexity {}".format(eval_ppx))
                sys.stdout.flush()
if __name__ == "__main__":
    _, data_set, max_sentence_lengths, dict_lengths = build_dataset(False)
    cleanup_checkpoints(model_dir, model_checkpoints)
    train()
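
# A minimal usage note (an assumption, not part of the original script): once the
# dictionaries have been pickled under tmp/ and checkpoints exist in tmp/chat, a
# later run could resume training instead of starting over by skipping the
# cleanup step and reusing the stored dictionaries:
#
#   _, data_set, max_sentence_lengths, dict_lengths = build_dataset(use_stored_dictionary=True)
#   train()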