From 0b62ac30ed2bb4ab5a82d6635fb1ac8a98ecdd3f Mon Sep 17 00:00:00 2001
From: Allen Shi
Date: Sat, 29 Jun 2019 22:28:00 -0400
Subject: [PATCH 1/3] Fix issue #174

---
 examples/transformer/bleu_tool.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/examples/transformer/bleu_tool.py b/examples/transformer/bleu_tool.py
index d98758af..f252670a 100755
--- a/examples/transformer/bleu_tool.py
+++ b/examples/transformer/bleu_tool.py
@@ -132,9 +132,12 @@ def compute_bleu(reference_corpus,
 
     if use_bp:
         ratio = translation_length / reference_length
-        if ratio == 0:
+        if ratio <= 0:
             bp = 0
-        bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
+        elif ratio < 1.0:
+            bp = math.exp(1 - 1. / ratio)
+        else:
+            bp = 1.0
     bleu = geo_mean * bp
     return np.float32(bleu)
 

From b5c8bdd81be77eac5434a476f62c5bdef56f1838 Mon Sep 17 00:00:00 2001
From: Allen Shi
Date: Mon, 1 Jul 2019 11:09:20 -0400
Subject: [PATCH 2/3] Fix Issue #174 and #175

---
 examples/transformer/utils/data_utils.py      |  50 +++---
 texar/modules/embedders/position_embedders.py | 142 ++++++++++++------
 2 files changed, 131 insertions(+), 61 deletions(-)

diff --git a/examples/transformer/utils/data_utils.py b/examples/transformer/utils/data_utils.py
index 693e8a04..f69694c1 100644
--- a/examples/transformer/utils/data_utils.py
+++ b/examples/transformer/utils/data_utils.py
@@ -20,16 +20,27 @@
 
 # pylint: disable=no-member
 
+
 def load_data_numpy(input_dir, prefix):
-    train_data = np.load(os.path.join(input_dir,\
-        prefix + 'train.npy'), encoding='latin1').tolist()
-    dev_data = np.load(os.path.join(input_dir,\
-        prefix + 'valid.npy'), encoding='latin1').tolist()
-    test_data = np.load(os.path.join(input_dir,\
-        prefix + 'test.npy'), encoding='latin1').tolist()
-    print('train data size:{}'.format(len(train_data)))
+    train_data = np.load(
+        os.path.join(input_dir, prefix + "train.npy"),
+        encoding="latin1",
+        allow_pickle=True,
+    ).tolist()
+    dev_data = np.load(
+        os.path.join(input_dir, prefix + "valid.npy"),
+        encoding="latin1",
+        allow_pickle=True,
+    ).tolist()
+    test_data = np.load(
+        os.path.join(input_dir, prefix + "test.npy"),
+        encoding="latin1",
+        allow_pickle=True,
+    ).tolist()
+    print("train data size:{}".format(len(train_data)))
     return train_data, dev_data, test_data
 
+
 def seq2seq_pad_concat_convert(xy_batch, eos_id=2, bos_id=1):
     """
     Args:
@@ -55,21 +66,23 @@ def seq2seq_pad_concat_convert(xy_batch, eos_id=2, bos_id=1):
     y_block = _concat_examples(y_seqs, padding=0)
 
     # Add EOS
-    x_block = np.pad(x_block, ((0, 0), (0, 1)), 'constant',
-                     constant_values=0)
+    x_block = np.pad(x_block, ((0, 0), (0, 1)), "constant", constant_values=0)
     for i_batch, seq in enumerate(x_seqs):
         x_block[i_batch, len(seq)] = eos_id
 
-    y_out_block = np.pad(y_block, ((0, 0), (0, 1)), 'constant',
-                         constant_values=0)
+    y_out_block = np.pad(
+        y_block, ((0, 0), (0, 1)), "constant", constant_values=0
+    )
     for i_batch, seq in enumerate(y_seqs):
         y_out_block[i_batch, len(seq)] = eos_id
 
     # Add BOS in target language
-    y_in_block = np.pad(y_block, ((0, 0), (1, 0)), 'constant',
-                        constant_values=bos_id)
+    y_in_block = np.pad(
+        y_block, ((0, 0), (1, 0)), "constant", constant_values=bos_id
+    )
     return x_block, y_in_block, y_out_block
 
+
 def source_pad_concat_convert(x_seqs, eos_id=2, bos_id=1):
     """
     This function is used when testing the model without target input.
@@ -77,14 +90,15 @@ def source_pad_concat_convert(x_seqs, eos_id=2, bos_id=1):
     x_block = _concat_examples(x_seqs, padding=0)
 
     # add EOS
-    x_block = np.pad(x_block, ((0, 0), (0, 1)), 'constant', constant_values=0)
+    x_block = np.pad(x_block, ((0, 0), (0, 1)), "constant", constant_values=0)
     for i_batch, seq in enumerate(x_seqs):
         x_block[i_batch, len(seq)] = eos_id
     return x_block
 
+
 def _concat_examples(arrays, padding=0):
     if len(arrays) == 0:
-        raise ValueError('batch is empty')
+        raise ValueError("batch is empty")
 
     first_elem = arrays[0]
     assert isinstance(first_elem, np.ndarray)
@@ -102,8 +116,8 @@ def _concat_examples(arrays, padding=0):
         result[(i,) + slices] = src
     return result
 
+
 def write_words(words_list, filename):
-    with codecs.open(filename, 'w+', 'utf-8') as myfile:
+    with codecs.open(filename, "w+", "utf-8") as myfile:
         for words in words_list:
-            myfile.write(' '.join(words) + '\n')
-
+            myfile.write(" ".join(words) + "\n")
diff --git a/texar/modules/embedders/position_embedders.py b/texar/modules/embedders/position_embedders.py
index ed31c887..5a248d9c 100644
--- a/texar/modules/embedders/position_embedders.py
+++ b/texar/modules/embedders/position_embedders.py
@@ -27,14 +27,16 @@
 from texar.modules.embedders import embedder_utils
 from texar.utils.mode import is_train_mode
 from texar.utils.shapes import mask_sequences
+from texar.utils.shapes import shape_list
 
 # pylint: disable=arguments-differ, invalid-name
 
 __all__ = [
     "PositionEmbedder",
-    "SinusoidsPositionEmbedder",
+    "SinusoidsPositionEmbedder"
 ]
 
+
 class PositionEmbedder(EmbedderBase):
     """Simple position embedder that maps position indexes into
     embeddings via lookup.
@@ -68,18 +70,21 @@ def __init__(self, init_value=None, position_size=None, hparams=None):
 
         if init_value is None and position_size is None:
             raise ValueError(
-                "Either `init_value` or `position_size` is required.")
+                "Either `init_value` or `position_size` is required."
+            )
 
-        self._init_parameterized_embedding(init_value, position_size,
-                                           self._hparams)
+        self._init_parameterized_embedding(
+            init_value, position_size, self._hparams
+        )
 
         self._position_size = position_size
         if position_size is None:
             self._position_size = self._num_embeds
         if self._position_size != self._num_embeds:
             raise ValueError(
-                'position_size must equal to init_value.shape[0].'
-                'Got %d and %d' % (self._position_size, self._num_embeds))
+                "position_size must be equal to init_value.shape[0]. "
+                "Got %d and %d" % (self._position_size, self._num_embeds)
+            )
 
         self._built = True
 
@@ -148,11 +153,13 @@ def _build(self, positions=None, sequence_length=None, mode=None, **kwargs):
             A `Tensor` of shape `shape(inputs) + embedding dimension`.
         """
         # Gets embedder inputs
+        # pylint:disable=too-many-locals
         inputs = positions
         if positions is None:
             if sequence_length is None:
                 raise ValueError(
-                    'Either `positions` or `sequence_length` is required.')
+                    "Either `positions` or `sequence_length` is required."
+                )
             max_length = tf.reduce_max(sequence_length)
             single_inputs = tf.range(start=0, limit=max_length, dtype=tf.int32)
             # Expands `single_inputs` to have shape [batch_size, max_length]
@@ -166,38 +173,46 @@ def _build(self, positions=None, sequence_length=None, mode=None, **kwargs):
 
         # Gets dropout strategy
         st = self._hparams.dropout_strategy
-        if positions is None and st == 'item':
+        if positions is None and st == "item":
             # If `inputs` is based on `sequence_length`, then dropout
             # strategies 'item' and 'item_type' have the same effect, we
             # use 'item_type' to avoid unknown noise_shape in the 'item'
             # strategy
-            st = 'item_type'
+            st = "item_type"
 
         # Dropouts as 'item_type' before embedding
-        if st == 'item_type':
+        if st == "item_type":
             dropout_layer = self._get_dropout_layer(
-                self._hparams, dropout_strategy=st)
+                self._hparams, dropout_strategy=st
+            )
             if dropout_layer:
-                embedding = dropout_layer.apply(inputs=embedding,
-                                                training=is_training)
+                embedding = dropout_layer.apply(
+                    inputs=embedding, training=is_training
+                )
 
         # Embeds
         outputs = tf.nn.embedding_lookup(embedding, inputs, **kwargs)
 
         # Dropouts as 'item' or 'elements' after embedding
-        if st != 'item_type':
+        if st != "item_type":
             dropout_layer = self._get_dropout_layer(
-                self._hparams, ids_rank=ids_rank, dropout_input=outputs,
-                dropout_strategy=st)
+                self._hparams,
+                ids_rank=ids_rank,
+                dropout_input=outputs,
+                dropout_strategy=st,
+            )
             if dropout_layer:
-                outputs = dropout_layer.apply(inputs=outputs,
-                                              training=is_training)
+                outputs = dropout_layer.apply(
+                    inputs=outputs, training=is_training
+                )
 
         # Optionally masks
         if sequence_length is not None:
             outputs = mask_sequences(
-                outputs, sequence_length,
-                tensor_rank=len(inputs.shape.dims) + self._dim_rank)
+                outputs,
+                sequence_length,
+                tensor_rank=len(inputs.shape.dims) + self._dim_rank,
+            )
 
         return outputs
 
@@ -248,27 +263,38 @@ class SinusoidsPositionEmbedder(EmbedderBase):
     .. document private functions
     .. automethod:: _build
     """
+
     def __init__(self, position_size, hparams=None):
         EmbedderBase.__init__(self, hparams=hparams)
 
-        dim = self._hparams.dim
-        num_timescales = dim // 2
+        self._num_embeds = position_size
+        self._dim = self._hparams.dim
+        self._cache_embeddings = self._hparams.cache_embeddings
+
+        num_timescales = self._dim // 2
         min_timescale = self._hparams.min_timescale
         max_timescale = self._hparams.max_timescale
 
-        positions = tf.to_float(tf.range(position_size, dtype=tf.int32))
-        log_timescale_increment = (
-            math.log(float(max_timescale) / float(min_timescale)) /
-            (tf.to_float(num_timescales) - 1))
+        log_timescale_increment = math.log(
+            float(max_timescale) / float(min_timescale)
+        ) / (tf.to_float(num_timescales) - 1)
         inv_timescales = min_timescale * tf.exp(
-            tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
-        scaled_time = tf.expand_dims(positions, 1) \
-            * tf.expand_dims(inv_timescales, 0)
-        signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
-        signal = tf.pad(signal, [[0, 0], [0, tf.mod(dim, 2)]])
-        self.signal = signal
-
-    def default_hparams(self):
+            tf.to_float(tf.range(num_timescales)) * -log_timescale_increment
+        )
+        self.inv_timescales = inv_timescales
+
+        if self._cache_embeddings:
+            if position_size is None:
+                raise ValueError(
+                    "'position_size' must not be None when "
+                    "'cache_embeddings' is set to True"
+                )
+            positions = tf.to_float(tf.range(position_size, dtype=tf.int32))
+            signal = self._compute_embeddings(positions)
+            self.signal = signal
+
+    @staticmethod
+    def default_hparams():
         """Returns a dictionary of hyperparameters with default values
         We use a geometric sequence of timescales starting with
         min_timescale and ending with max_timescale. The number of different
@@ -280,17 +306,42 @@ def default_hparams(self):
                 'min_timescale': 1.0,
                 'max_timescale': 10000.0,
                 'dim': 512,
+                'cache_embeddings': True,
                 'name':'sinusoid_posisiton_embedder',
             }
+
+        Here:
+
+        `"cache_embeddings"`: bool
+            If `True`, precompute embeddings for positions in range
+            `[0, position_size - 1]`. This leads to faster lookup but
+            requires lookup indices to be within this range.
+
+            If `False`, embeddings are computed on-the-fly during lookup.
+            Set to `False` if your application needs to handle sequences
+            of arbitrary length, or requires embeddings at negative
+            positions.
         """
         hparams = {
-            'min_timescale': 1.0,
-            'max_timescale': 1.0e4,
-            'dim': 512,
-            'name':'sinusoid_posisiton_embedder',
+            "min_timescale": 1.0,
+            "max_timescale": 1.0e4,
+            "dim": 512,
+            "cache_embeddings": True,
+            "name": "sinusoid_posisiton_embedder",
         }
         return hparams
 
+    def _compute_embeddings(self, positions):
+        inv_timescales = self.inv_timescales
+        scaled_time = tf.reshape(tf.cast(positions, inv_timescales.dtype),
+                                 (-1, 1)) * tf.expand_dims(inv_timescales, 0)
+        signal = tf.concat(
+            [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1
+        )
+        signal = tf.pad(signal, [[0, 0], [0, tf.mod(self._dim, 2)]])
+        signal = tf.reshape(signal, shape_list(positions) + [self._dim])
+        return signal
+
     def _build(self, positions=None, sequence_length=None):
         """Embeds. Either :attr:`positions` or :attr:`sequence_length`
         is required:
@@ -312,18 +363,23 @@ def _build(self, positions=None, sequence_length=None):
         Returns:
             A `Tensor` of shape `[batch_size, max_time, dim]`.
         """
-        inputs = positions
+
         if positions is None:
             if sequence_length is None:
                 raise ValueError(
-                    'Either `positions` or `sequence_length` is required.')
+                    "Either `positions` or `sequence_length` is required."
+                )
             max_length = tf.reduce_max(sequence_length)
             single_inputs = tf.range(start=0, limit=max_length, dtype=tf.int32)
             # Expands `single_inputs` to have shape [batch_size, max_length]
             expander = tf.expand_dims(tf.ones_like(sequence_length), -1)
             inputs = expander * tf.expand_dims(single_inputs, 0)
+        else:
+            inputs = positions
 
-        embedding = self.signal
-        outputs = tf.nn.embedding_lookup(embedding, inputs)
-        return outputs
+        if self._cache_embeddings:
+            outputs = tf.nn.embedding_lookup(self.signal, inputs)
+        else:
+            outputs = self._compute_embeddings(inputs)
+
+        return outputs
 

From 192d67a1a144d726ebd4305ed7f4bbfac4688556 Mon Sep 17 00:00:00 2001
From: Zhiting Hu
Date: Mon, 1 Jul 2019 20:58:55 -0400
Subject: [PATCH 3/3] Update SinusoidsPositionEmbedder doc and CHANGELOG

---
 CHANGELOG.md                                  | 6 ++++--
 texar/modules/embedders/position_embedders.py | 4 +++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 778d37da..f15235d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,13 +12,15 @@
 * Allow passing a Tensor to `output_layer` of decoders' constructors -- used for weight tie b/w the output layer and input embedding matrix. ([#126](https://github.com/asyml/texar/pull/126))
 * `TransformerDecoder` constructor interface made exact the same with `RNN decoders` constructor interfaces. ([#126](https://github.com/asyml/texar/pull/126))
 * Refactor decoder `Helper`s to allow two-argument `embedding_fn` (supporting for position embedding). ([#126](https://github.com/asyml/texar/pull/126))
+* Refactor `SinusoidsPositionEmbedder` to enable infinitely large or negative position indexes. ([#176](https://github.com/asyml/texar/pull/176))
 
 ### Fixes
 
 * Fix `texar.losses.reduce_batch_time` when `sequence` has dtype other than `tf.float32`. ([#143](https://github.com/asyml/texar/issues/143))
 * Fix `texar.losses.reduce_dimensions` when `average_axes` or `sum_axes` is `int`. ([#141](https://github.com/asyml/texar/pull/141))
-* Fix [GPT-2](https://github.com/asyml/texar/tree/master/examples/gpt-2) tokenization loading path. ([165](https://github.com/asyml/texar/pull/165))
-* Fix [examples/vae_text](https://github.com/asyml/texar/tree/master/examples/vae_text) EOS bug. ([168](https://github.com/asyml/texar/pull/168))
+* Fix [GPT-2](https://github.com/asyml/texar/tree/master/examples/gpt-2) tokenization loading path. ([#165](https://github.com/asyml/texar/pull/165))
+* Fix [examples/vae_text](https://github.com/asyml/texar/tree/master/examples/vae_text) EOS bug. ([#168](https://github.com/asyml/texar/pull/168))
+* Fix transformer [bleu_tool.py](https://github.com/asyml/texar/blob/master/examples/transformer/bleu_tool.py) when `translation_length` is 0. ([#176](https://github.com/asyml/texar/pull/176))
 
 ## [v0.2.0](https://github.com/asyml/texar/releases/tag/v0.2.0) (2019-04-09)
 

diff --git a/texar/modules/embedders/position_embedders.py b/texar/modules/embedders/position_embedders.py
index 5a248d9c..f623c4ea 100644
--- a/texar/modules/embedders/position_embedders.py
+++ b/texar/modules/embedders/position_embedders.py
@@ -258,7 +258,9 @@ class SinusoidsPositionEmbedder(EmbedderBase):
 
     Args:
         position_size (int): The number of possible positions, e.g., the maximum
-            sequence length.
+            sequence length. Set `position_size=None` and
+            `hparams['cache_embeddings']=False` to enable infinitely large or
+            negative position indexes.
 
     .. document private functions
     .. automethod:: _build
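
Reviewer note (not part of the patches): a minimal usage sketch of the refactored
SinusoidsPositionEmbedder in both modes. The hparam names follow the diff above; the
callable-module interface and the example sizes (position_size=256, dim=512) are
assumptions for illustration only.

    import tensorflow as tf
    from texar.modules.embedders.position_embedders import SinusoidsPositionEmbedder

    # Cached mode (default): embeddings for positions [0, position_size - 1]
    # are precomputed in __init__, so lookups must stay within that range.
    cached = SinusoidsPositionEmbedder(position_size=256, hparams={'dim': 512})

    # On-the-fly mode: no table is cached, so `position_size` may be None and
    # arbitrary (even negative) position indexes are allowed.
    dynamic = SinusoidsPositionEmbedder(
        position_size=None,
        hparams={'dim': 512, 'cache_embeddings': False})

    positions = tf.constant([[0, 1, 2, 3]], dtype=tf.int32)
    cached_out = cached(positions=positions)    # Tensor of shape [1, 4, 512]
    dynamic_out = dynamic(positions=positions)  # Tensor of shape [1, 4, 512]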