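"""Training entry point.

Builds the model and input functions from the hparams, model, and data
registries, constructs a `tf.estimator.Estimator` (or a `TPUEstimator` when
running on TPU), and alternates training and evaluation until the requested
number of training epochs has been reached.
"""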
import cloud
import os
import sys
import subprocess
import random
import tensorflow as tf
import numpy as np
import time
import logging
from .hparams.registry import get_hparams
from .models.registry import get_model
from .data.registry import get_input_fns
from .training.lr_schemes import get_lr
from .training.envs import get_env
from .training import flags
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
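# Example invocation (illustrative only; the package path, hparams set name,
# and directories below are placeholders, not values defined in this file):
#
#   python -m <package>.train --env=local --hparams=<hparams_set> \
#       --data_dir=/tmp/data --output_dir=/tmp/output --train_epochs=10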
def init_flags():
  tf.flags.DEFINE_string("env", None, "Which environment to use.")  # required
  tf.flags.DEFINE_string("hparams", None, "Which hparams to use.")  # required

  # Utility flags
  tf.flags.DEFINE_string("hparam_override", "",
                         "Run-specific hparam settings to use.")
  tf.flags.DEFINE_boolean("fresh", False, "Remove output_dir before running.")
  tf.flags.DEFINE_integer("seed", None, "Random seed.")
  tf.flags.DEFINE_integer("train_epochs", None,
                          "Number of training epochs to perform.")
  tf.flags.DEFINE_integer("eval_steps", None,
                          "Number of evaluation steps to perform.")

  # TPU flags
  tf.flags.DEFINE_string("tpu_name", "", "Name of TPU(s).")
  tf.flags.DEFINE_integer(
      "tpu_iterations_per_loop", 1000,
      "The number of training steps to run on TPU before "
      "returning control to the CPU.")
  tf.flags.DEFINE_integer(
      "tpu_shards", 8, "The number of TPU shards in the system "
      "(a single Cloud TPU has 8 shards).")
  tf.flags.DEFINE_boolean(
      "tpu_summarize", False, "Save summaries for TensorBoard. "
      "Warning: this will slow down execution.")
  tf.flags.DEFINE_boolean("tpu_dedicated", False,
                          "Do not use preemptible TPUs.")

  tf.flags.DEFINE_string("data_dir", None, "The data directory.")
  tf.flags.DEFINE_string("output_dir", None, "The output directory.")
  tf.flags.DEFINE_integer("eval_every", 1000,
                          "Number of steps between evaluations.")

  tf.logging.set_verbosity(tf.logging.INFO)


FLAGS = None


def init_random_seeds():
  tf.set_random_seed(FLAGS.seed)
  random.seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)


def init_model(hparams_name):
  flags.validate_flags(FLAGS)

  tf.reset_default_graph()

  hparams = get_hparams(hparams_name)
  hparams = hparams.parse(FLAGS.hparam_override)
  hparams = flags.update_hparams(FLAGS, hparams, hparams_name)

  # Set a larger eval_every for TPUs to improve utilization.
  if FLAGS.env == "tpu":
    FLAGS.eval_every = max(FLAGS.eval_every, 5000)
  hparams.tpu_summarize = FLAGS.tpu_summarize

  tf.logging.warn("\n-----------------------------------------\n"
                  "BEGINNING RUN:\n"
                  "\t hparams: %s\n"
                  "\t output_dir: %s\n"
                  "\t data_dir: %s\n"
                  "-----------------------------------------\n" %
                  (hparams_name, hparams.output_dir, hparams.data_dir))

  return hparams


def construct_estimator(model_fn, hparams, tpu=None):
  if hparams.use_tpu:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=tpu.name)
    master = tpu_cluster_resolver.get_master()
    config = tpu_config.RunConfig(
        master=master,
        evaluation_master=master,
        model_dir=hparams.output_dir,
        session_config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=True),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.tpu_iterations_per_loop,
            num_shards=FLAGS.tpu_shards),
        save_checkpoints_steps=FLAGS.eval_every)
    estimator = tpu_estimator.TPUEstimator(
        use_tpu=hparams.use_tpu,
        model_fn=model_fn,
        model_dir=hparams.output_dir,
        config=config,
        train_batch_size=hparams.batch_size,
        eval_batch_size=hparams.batch_size)
  else:
    gpu_config = tf.ConfigProto(allow_soft_placement=True)
    gpu_config.gpu_options.allow_growth = True
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=FLAGS.eval_every, session_config=gpu_config)
    estimator = tf.estimator.Estimator(
        model_fn=tf.contrib.estimator.replicate_model_fn(model_fn),
        model_dir=hparams.output_dir,
        config=run_config)

  return estimator


def _run(hparams_name):
  """Runs training and evaluation for a single hparams set."""
  hparams = init_model(hparams_name)
  original_batch_size = hparams.batch_size

  if tf.gfile.Exists(hparams.output_dir) and FLAGS.fresh:
    tf.gfile.DeleteRecursively(hparams.output_dir)
  if not tf.gfile.Exists(hparams.output_dir):
    tf.gfile.MakeDirs(hparams.output_dir)

  model_fn = get_model(hparams)
  train_input_fn, eval_input_fn, test_input_fn = get_input_fns(hparams)

  tpu = None
  if hparams.use_tpu:
    cloud.instance.tpu.clean()
    tpu = cloud.instance.tpu.get(preemptible=not FLAGS.tpu_dedicated)
  estimator = construct_estimator(model_fn, hparams, tpu)

  if not hparams.use_tpu:
    # Build the model graph once in a local session outside the Estimator
    # (non-TPU runs only).
    features, labels = train_input_fn()
    sess = tf.Session()
    tf.train.get_or_create_global_step()
    model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
    sess.run(tf.global_variables_initializer())

  # Output metadata about the run.
  with tf.gfile.GFile(os.path.join(hparams.output_dir, "hparams.txt"),
                      "w") as hparams_file:
    hparams_file.write("{}\n".format(time.time()))
    hparams_file.write("{}\n".format(str(hparams)))

  def loop(steps=FLAGS.eval_every):
    estimator.train(train_input_fn, steps=steps)
    if eval_input_fn:
      estimator.evaluate(eval_input_fn, steps=hparams.eval_steps, name="eval")
    if test_input_fn:
      estimator.evaluate(test_input_fn, steps=hparams.eval_steps, name="test")

  # Train for a single step and evaluate once, then alternate training and
  # evaluation in increments of eval_every steps. k tracks the number of
  # epochs of data processed so far.
  loop(1)
  steps = estimator.get_variable_value("global_step")
  k = steps * original_batch_size / float(hparams.epoch_size)
  while k <= hparams.train_epochs:
    tf.logging.info("Beginning epoch %f / %d" % (k, hparams.train_epochs))
    if tpu and not tpu.usable:
      # The TPU is no longer usable (e.g. it was preempted); delete it and
      # provision a replacement before continuing.
      tpu.delete(async=True)
      tpu = cloud.instance.tpu.get(preemptible=not FLAGS.tpu_dedicated)
      estimator = construct_estimator(model_fn, hparams, tpu)
    loop()
    steps = estimator.get_variable_value("global_step")
    k = steps * original_batch_size / float(hparams.epoch_size)


def main(_):
  global FLAGS
  FLAGS = tf.app.flags.FLAGS
  init_random_seeds()
  if FLAGS.env != "local":
    cloud.connect()
  for hparams_name in FLAGS.hparams.split(","):
    _run(hparams_name)


if __name__ == "__main__":
  init_flags()
  tf.app.run()