|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2018 The Google AI Language Team Authors. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +"""BERT finetuning runner with TF-Hub.""" |
| 16 | + |
| 17 | +from __future__ import absolute_import |
| 18 | +from __future__ import division |
| 19 | +from __future__ import print_function |
| 20 | + |
| 21 | +import os |
| 22 | +import optimization |
| 23 | +import run_classifier |
| 24 | +import tokenization |
| 25 | +import tensorflow as tf |
| 26 | +import tensorflow_hub as hub |
| 27 | + |
# Shorthand for the TensorFlow flags module and the global parsed-flag holder.
flags = tf.flags
FLAGS = flags.FLAGS

# The only flag defined in this file; the remaining FLAGS.* attributes read
# below are not defined here (presumably by the imported `run_classifier`
# module -- verify there).
flags.DEFINE_string("bert_hub_module_handle", None,
                    "Handle for the BERT TF-Hub module.")
| 35 | + |
| 36 | + |
def create_model(is_training, input_ids, input_mask, segment_ids, labels,
                 num_labels):
  """Builds a classifier on top of the BERT TF-Hub module.

  Args:
    is_training: Whether the graph is built for training. Controls the
      "train" module tag and whether dropout is applied.
    input_ids: Fed to the Hub module's "tokens" signature as `input_ids`.
    input_mask: Fed to the Hub module's "tokens" signature as `input_mask`.
    segment_ids: Fed to the Hub module's "tokens" signature as `segment_ids`.
    labels: Class indices; converted with `tf.one_hot` to compute the loss.
    num_labels: Number of output classes.

  Returns:
    A tuple `(loss, per_example_loss, logits)`, where `loss` is the mean of
    `per_example_loss`.
  """
  # The module is loaded with the "train" tag only when training.
  module_tags = {"train"} if is_training else set()
  encoder = hub.Module(
      FLAGS.bert_hub_module_handle, tags=module_tags, trainable=True)
  encoder_outputs = encoder(
      inputs={
          "input_ids": input_ids,
          "input_mask": input_mask,
          "segment_ids": segment_ids,
      },
      signature="tokens",
      as_dict=True)

  # This demo classifies the entire segment, so the pooled output is used.
  # For token-level output, use encoder_outputs["sequence_output"] instead.
  pooled = encoder_outputs["pooled_output"]
  hidden_size = pooled.shape[-1].value

  # Classification head parameters (created outside the "loss" scope).
  output_weights = tf.get_variable(
      "output_weights",
      [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = tf.get_variable(
      "output_bias",
      [num_labels],
      initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):
    classifier_input = pooled
    if is_training:
      # keep_prob=0.9, i.e. 0.1 dropout.
      classifier_input = tf.nn.dropout(classifier_input, keep_prob=0.9)

    logits = tf.nn.bias_add(
        tf.matmul(classifier_input, output_weights, transpose_b=True),
        output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    # Cross-entropy against the one-hot encoded labels.
    targets = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(targets * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

  return (loss, per_example_loss, logits)
| 87 | + |
| 88 | + |
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps, use_tpu):
  """Returns `model_fn` closure for TPUEstimator.

  Args:
    num_labels: Number of output classes, forwarded to `create_model`.
    learning_rate: Forwarded to `optimization.create_optimizer`.
    num_train_steps: Forwarded to `optimization.create_optimizer`.
    num_warmup_steps: Forwarded to `optimization.create_optimizer`.
    use_tpu: Forwarded to `optimization.create_optimizer`.

  Returns:
    A `model_fn(features, labels, mode, params)` callable that builds a
    `TPUEstimatorSpec` for the TRAIN and EVAL modes.
  """

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Raises:
      ValueError: If `mode` is neither TRAIN nor EVAL.
    """

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      # Lazy logging arguments, consistent with the logging calls in `main`.
      tf.logging.info(" name = %s, shape = %s", name, features[name].shape)

    # Fail fast: reject unsupported modes before reading label features or
    # building the (expensive) Hub module graph, so the caller always gets
    # this ValueError rather than an unrelated error from further down.
    supported_modes = (tf.estimator.ModeKeys.TRAIN,
                       tf.estimator.ModeKeys.EVAL)
    if mode not in supported_modes:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (total_loss, per_example_loss, logits) = create_model(
        is_training, input_ids, input_mask, segment_ids, label_ids, num_labels)

    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op)

    # Only EVAL remains at this point.
    def metric_fn(per_example_loss, label_ids, logits):
      """Computes accuracy and mean loss from the tensors passed below."""
      predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
      accuracy = tf.metrics.accuracy(label_ids, predictions)
      loss = tf.metrics.mean(per_example_loss)
      return {
          "eval_accuracy": accuracy,
          "eval_loss": loss,
      }

    eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics)

  return model_fn
| 141 | + |
| 142 | + |
def create_tokenizer_from_hub_module():
  """Builds a `FullTokenizer` from the Hub module's tokenization info.

  The module named by `FLAGS.bert_hub_module_handle` is loaded in a separate
  graph and its "tokenization_info" signature is evaluated in a session to
  obtain the vocab file and the lower-casing setting.

  Returns:
    A `tokenization.FullTokenizer` configured with those two values.
  """
  with tf.Graph().as_default():
    module = hub.Module(FLAGS.bert_hub_module_handle)
    info = module(signature="tokenization_info", as_dict=True)
    fetches = [info["vocab_file"], info["do_lower_case"]]
    with tf.Session() as sess:
      vocab_file, do_lower_case = sess.run(fetches)

  return tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
| 153 | + |
| 154 | + |
def main(_):
  """Runs fine-tuning and/or evaluation as selected by the command-line flags.

  Trains when `FLAGS.do_train` is set and evaluates when `FLAGS.do_eval` is
  set; evaluation results are logged and written to `eval_results.txt` inside
  `FLAGS.output_dir`.

  Raises:
    ValueError: If neither `do_train` nor `do_eval` is set, or if
      `FLAGS.task_name` is not one of the known tasks.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Lower-cased task name -> processor class.
  task_processors = {
      "cola": run_classifier.ColaProcessor,
      "mnli": run_classifier.MnliProcessor,
      "mrpc": run_classifier.MrpcProcessor,
  }

  if not (FLAGS.do_train or FLAGS.do_eval):
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()
  if task_name not in task_processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = task_processors[task_name]()
  label_list = processor.get_labels()
  tokenizer = create_tokenizer_from_hub_module()

  # A cluster resolver is only created when a TPU is requested and named.
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    tpu_cluster_resolver = None

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=(
              tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2)))

  # Step counts stay None unless training was requested.
  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    steps_per_epoch = len(train_examples) / FLAGS.train_batch_size
    num_train_steps = int(steps_per_epoch * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  model_fn = model_fn_builder(
      num_labels=len(label_list),
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  if FLAGS.do_train:
    train_features = run_classifier.convert_examples_to_features(
        train_examples, label_list, FLAGS.max_seq_length, tokenizer)
    tf.logging.info("***** Running training *****")
    tf.logging.info(" Num examples = %d", len(train_examples))
    tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info(" Num steps = %d", num_train_steps)
    train_input_fn = run_classifier.input_fn_builder(
        features=train_features,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_eval:
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    eval_features = run_classifier.convert_examples_to_features(
        eval_examples, label_list, FLAGS.max_seq_length, tokenizer)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info(" Num examples = %d", len(eval_examples))
    tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)

    # steps=None tells the estimator to run through the entire set. On the
    # TPU the number of steps must be given explicitly, which makes eval
    # slightly WRONG there because the last (partial) batch is truncated.
    if FLAGS.use_tpu:
      eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
    else:
      eval_steps = None

    eval_input_fn = run_classifier.input_fn_builder(
        features=eval_features,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=bool(FLAGS.use_tpu))

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      tf.logging.info("***** Eval results *****")
      for key in sorted(result):
        value_text = str(result[key])
        tf.logging.info(" %s = %s", key, value_text)
        writer.write("%s = %s\n" % (key, value_text))
| 268 | + |
| 269 | + |
if __name__ == "__main__":
  # These flags have no usable default, so the run must supply them.
  for required_flag in ("data_dir", "task_name", "bert_hub_module_handle",
                        "output_dir"):
    flags.mark_flag_as_required(required_flag)
  tf.app.run()
0 commit comments