Skip to content

Commit bee6030

Browse files
Adding TF Hub support
1 parent f39e881 commit bee6030

5 files changed

+294
-16
lines changed

Diff for: README.md

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# BERT
22

3+
**\*\*\*\*\* New February 7th, 2019: TfHub Module \*\*\*\*\***
4+
5+
BERT has been uploaded to [TensorFlow Hub](https://tfhub.dev). See
6+
`run_classifier_with_tfhub.py` for an example of how to use the TF Hub module.
7+
38
**\*\*\*\*\* New November 23rd, 2018: Un-normalized multilingual model + Thai +
49
Mongolian \*\*\*\*\***
510

Diff for: create_pretraining_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
import collections
2222
import random
23-
import tensorflow as tf
2423
import tokenization
24+
import tensorflow as tf
2525

2626
flags = tf.flags
2727

Diff for: modeling.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import json
2424
import math
2525
import re
26+
import numpy as np
2627
import six
2728
import tensorflow as tf
2829

@@ -133,7 +134,7 @@ def __init__(self,
133134
input_ids,
134135
input_mask=None,
135136
token_type_ids=None,
136-
use_one_hot_embeddings=True,
137+
use_one_hot_embeddings=False,
137138
scope=None):
138139
"""Constructor for BertModel.
139140
@@ -145,9 +146,7 @@ def __init__(self,
145146
input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
146147
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
147148
use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
148-
embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
149-
it is much faster if this is True, on the CPU or GPU, it is faster if
150-
this is False.
149+
embeddings or tf.embedding_lookup() for the word embeddings.
151150
scope: (optional) variable scope. Defaults to "bert".
152151
153152
Raises:
@@ -262,20 +261,20 @@ def get_embedding_table(self):
262261
return self.embedding_table
263262

264263

265-
def gelu(input_tensor):
264+
def gelu(x):
266265
"""Gaussian Error Linear Unit.
267266
268267
This is a smoother version of the RELU.
269268
Original paper: https://arxiv.org/abs/1606.08415
270-
271269
Args:
272-
input_tensor: float Tensor to perform activation.
270+
x: float Tensor to perform activation.
273271
274272
Returns:
275-
`input_tensor` with the GELU activation applied.
273+
`x` with the GELU activation applied.
276274
"""
277-
cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
278-
return input_tensor * cdf
275+
cdf = 0.5 * (1.0 + tf.tanh(
276+
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
277+
return x * cdf
279278

280279

281280
def get_activation(activation_string):
@@ -394,8 +393,7 @@ def embedding_lookup(input_ids,
394393
initializer_range: float. Embedding initialization range.
395394
word_embedding_name: string. Name of the embedding table.
396395
use_one_hot_embeddings: bool. If True, use one-hot method for word
397-
embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better
398-
for TPUs.
396+
embeddings. If False, use `tf.gather()`.
399397
400398
Returns:
401399
float Tensor of shape [batch_size, seq_length, embedding_size].
@@ -413,12 +411,12 @@ def embedding_lookup(input_ids,
413411
shape=[vocab_size, embedding_size],
414412
initializer=create_initializer(initializer_range))
415413

414+
flat_input_ids = tf.reshape(input_ids, [-1])
416415
if use_one_hot_embeddings:
417-
flat_input_ids = tf.reshape(input_ids, [-1])
418416
one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
419417
output = tf.matmul(one_hot_input_ids, embedding_table)
420418
else:
421-
output = tf.nn.embedding_lookup(embedding_table, input_ids)
419+
output = tf.gather(embedding_table, flat_input_ids)
422420

423421
input_shape = get_shape_list(input_ids)
424422

Diff for: run_classifier_with_tfhub.py

+275
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Google AI Language Team Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""BERT finetuning runner with TF-Hub."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import os
22+
import optimization
23+
import run_classifier
24+
import tokenization
25+
import tensorflow as tf
26+
import tensorflow_hub as hub
27+
28+
flags = tf.flags
29+
30+
FLAGS = flags.FLAGS
31+
32+
flags.DEFINE_string(
33+
"bert_hub_module_handle", None,
34+
"Handle for the BERT TF-Hub module.")
35+
36+
37+
def create_model(is_training, input_ids, input_mask, segment_ids, labels,
38+
num_labels):
39+
"""Creates a classification model."""
40+
tags = set()
41+
if is_training:
42+
tags.add("train")
43+
bert_module = hub.Module(
44+
FLAGS.bert_hub_module_handle,
45+
tags=tags,
46+
trainable=True)
47+
bert_inputs = dict(
48+
input_ids=input_ids,
49+
input_mask=input_mask,
50+
segment_ids=segment_ids)
51+
bert_outputs = bert_module(
52+
inputs=bert_inputs,
53+
signature="tokens",
54+
as_dict=True)
55+
56+
# In the demo, we are doing a simple classification task on the entire
57+
# segment.
58+
#
59+
# If you want to use the token-level output, use
60+
# bert_outputs["sequence_output"] instead.
61+
output_layer = bert_outputs["pooled_output"]
62+
63+
hidden_size = output_layer.shape[-1].value
64+
65+
output_weights = tf.get_variable(
66+
"output_weights", [num_labels, hidden_size],
67+
initializer=tf.truncated_normal_initializer(stddev=0.02))
68+
69+
output_bias = tf.get_variable(
70+
"output_bias", [num_labels], initializer=tf.zeros_initializer())
71+
72+
with tf.variable_scope("loss"):
73+
if is_training:
74+
# I.e., 0.1 dropout
75+
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
76+
77+
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
78+
logits = tf.nn.bias_add(logits, output_bias)
79+
log_probs = tf.nn.log_softmax(logits, axis=-1)
80+
81+
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
82+
83+
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
84+
loss = tf.reduce_mean(per_example_loss)
85+
86+
return (loss, per_example_loss, logits)
87+
88+
89+
def model_fn_builder(num_labels, learning_rate, num_train_steps,
90+
num_warmup_steps, use_tpu):
91+
"""Returns `model_fn` closure for TPUEstimator."""
92+
93+
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
94+
"""The `model_fn` for TPUEstimator."""
95+
96+
tf.logging.info("*** Features ***")
97+
for name in sorted(features.keys()):
98+
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
99+
100+
input_ids = features["input_ids"]
101+
input_mask = features["input_mask"]
102+
segment_ids = features["segment_ids"]
103+
label_ids = features["label_ids"]
104+
105+
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
106+
107+
(total_loss, per_example_loss, logits) = create_model(
108+
is_training, input_ids, input_mask, segment_ids, label_ids, num_labels)
109+
110+
output_spec = None
111+
if mode == tf.estimator.ModeKeys.TRAIN:
112+
train_op = optimization.create_optimizer(
113+
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
114+
115+
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
116+
mode=mode,
117+
loss=total_loss,
118+
train_op=train_op)
119+
elif mode == tf.estimator.ModeKeys.EVAL:
120+
121+
def metric_fn(per_example_loss, label_ids, logits):
122+
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
123+
accuracy = tf.metrics.accuracy(label_ids, predictions)
124+
loss = tf.metrics.mean(per_example_loss)
125+
return {
126+
"eval_accuracy": accuracy,
127+
"eval_loss": loss,
128+
}
129+
130+
eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
131+
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
132+
mode=mode,
133+
loss=total_loss,
134+
eval_metrics=eval_metrics)
135+
else:
136+
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
137+
138+
return output_spec
139+
140+
return model_fn
141+
142+
143+
def create_tokenizer_from_hub_module():
144+
"""Get the vocab file and casing info from the Hub module."""
145+
with tf.Graph().as_default():
146+
bert_module = hub.Module(FLAGS.bert_hub_module_handle)
147+
tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
148+
with tf.Session() as sess:
149+
vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
150+
tokenization_info["do_lower_case"]])
151+
return tokenization.FullTokenizer(
152+
vocab_file=vocab_file, do_lower_case=do_lower_case)
153+
154+
155+
def main(_):
156+
tf.logging.set_verbosity(tf.logging.INFO)
157+
158+
processors = {
159+
"cola": run_classifier.ColaProcessor,
160+
"mnli": run_classifier.MnliProcessor,
161+
"mrpc": run_classifier.MrpcProcessor,
162+
}
163+
164+
if not FLAGS.do_train and not FLAGS.do_eval:
165+
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
166+
167+
tf.gfile.MakeDirs(FLAGS.output_dir)
168+
169+
task_name = FLAGS.task_name.lower()
170+
171+
if task_name not in processors:
172+
raise ValueError("Task not found: %s" % (task_name))
173+
174+
processor = processors[task_name]()
175+
176+
label_list = processor.get_labels()
177+
178+
tokenizer = create_tokenizer_from_hub_module()
179+
180+
tpu_cluster_resolver = None
181+
if FLAGS.use_tpu and FLAGS.tpu_name:
182+
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
183+
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
184+
185+
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
186+
run_config = tf.contrib.tpu.RunConfig(
187+
cluster=tpu_cluster_resolver,
188+
master=FLAGS.master,
189+
model_dir=FLAGS.output_dir,
190+
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
191+
tpu_config=tf.contrib.tpu.TPUConfig(
192+
iterations_per_loop=FLAGS.iterations_per_loop,
193+
num_shards=FLAGS.num_tpu_cores,
194+
per_host_input_for_training=is_per_host))
195+
196+
train_examples = None
197+
num_train_steps = None
198+
num_warmup_steps = None
199+
if FLAGS.do_train:
200+
train_examples = processor.get_train_examples(FLAGS.data_dir)
201+
num_train_steps = int(
202+
len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
203+
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
204+
205+
model_fn = model_fn_builder(
206+
num_labels=len(label_list),
207+
learning_rate=FLAGS.learning_rate,
208+
num_train_steps=num_train_steps,
209+
num_warmup_steps=num_warmup_steps,
210+
use_tpu=FLAGS.use_tpu)
211+
212+
# If TPU is not available, this will fall back to normal Estimator on CPU
213+
# or GPU.
214+
estimator = tf.contrib.tpu.TPUEstimator(
215+
use_tpu=FLAGS.use_tpu,
216+
model_fn=model_fn,
217+
config=run_config,
218+
train_batch_size=FLAGS.train_batch_size,
219+
eval_batch_size=FLAGS.eval_batch_size)
220+
221+
if FLAGS.do_train:
222+
train_features = run_classifier.convert_examples_to_features(
223+
train_examples, label_list, FLAGS.max_seq_length, tokenizer)
224+
tf.logging.info("***** Running training *****")
225+
tf.logging.info(" Num examples = %d", len(train_examples))
226+
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
227+
tf.logging.info(" Num steps = %d", num_train_steps)
228+
train_input_fn = run_classifier.input_fn_builder(
229+
features=train_features,
230+
seq_length=FLAGS.max_seq_length,
231+
is_training=True,
232+
drop_remainder=True)
233+
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
234+
235+
if FLAGS.do_eval:
236+
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
237+
eval_features = run_classifier.convert_examples_to_features(
238+
eval_examples, label_list, FLAGS.max_seq_length, tokenizer)
239+
240+
tf.logging.info("***** Running evaluation *****")
241+
tf.logging.info(" Num examples = %d", len(eval_examples))
242+
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
243+
244+
# This tells the estimator to run through the entire set.
245+
eval_steps = None
246+
# However, if running eval on the TPU, you will need to specify the
247+
# number of steps.
248+
if FLAGS.use_tpu:
249+
# Eval will be slightly WRONG on the TPU because it will truncate
250+
# the last batch.
251+
eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
252+
253+
eval_drop_remainder = True if FLAGS.use_tpu else False
254+
eval_input_fn = run_classifier.input_fn_builder(
255+
features=eval_features,
256+
seq_length=FLAGS.max_seq_length,
257+
is_training=False,
258+
drop_remainder=eval_drop_remainder)
259+
260+
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
261+
262+
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
263+
with tf.gfile.GFile(output_eval_file, "w") as writer:
264+
tf.logging.info("***** Eval results *****")
265+
for key in sorted(result.keys()):
266+
tf.logging.info(" %s = %s", key, str(result[key]))
267+
writer.write("%s = %s\n" % (key, str(result[key])))
268+
269+
270+
if __name__ == "__main__":
271+
flags.mark_flag_as_required("data_dir")
272+
flags.mark_flag_as_required("task_name")
273+
flags.mark_flag_as_required("bert_hub_module_handle")
274+
flags.mark_flag_as_required("output_dir")
275+
tf.app.run()

Diff for: tokenization_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
import os
2020
import tempfile
21+
import tokenization
2122
import six
2223
import tensorflow as tf
23-
import tokenization
2424

2525

2626
class TokenizationTest(tf.test.TestCase):

0 commit comments

Comments
 (0)