diff --git a/texar/utils/beam_search.py b/texar/utils/beam_search.py
index 553828d2..58d341e3 100644
--- a/texar/utils/beam_search.py
+++ b/texar/utils/beam_search.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-# Modifications copyright (C) 2018 Texar
+# Modifications copyright (C) 2019 Texar
 # ==============================================================================
 """ Implemetation of beam seach with penalties.

@@ -103,6 +103,7 @@ def compute_batch_indices(batch_size, beam_size):
     Args:
         batch_size: Batch size
         beam_size: Size of the beam.
+
     Returns:
         batch_pos: [batch_size, beam_size] tensor of ids
     """
@@ -142,11 +143,12 @@ def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags,

             grow_finished, we will need to return the length penalized scors.
         flags: Tensor of bools for sequences that say whether a sequence
-        has reached EOS or not
+            has reached EOS or not
         beam_size: int
         batch_size: int
         prefix: string that will prefix unique names for the ops run.
         states_to_gather: dict (possibly nested) of decoding states.
+
     Returns:
         Tuple of
         (topk_seq [batch_size, beam_size, decode_length],
@@ -184,14 +186,14 @@ def gather(tensor, name):


 def beam_search(symbols_to_logits_fn,
-        initial_ids,
-        beam_size,
-        decode_length,
-        vocab_size,
-        alpha,
-        eos_id,
-        states=None,
-        stop_early=True):
+                initial_ids,
+                beam_size,
+                decode_length,
+                vocab_size,
+                alpha,
+                eos_id,
+                states=None,
+                stop_early=True):
     """Beam search with length penalties.

     Requires a function that can take the currently decoded sybmols and
@@ -222,20 +224,21 @@ def beam_search(symbols_to_logits_fn,

     Args:
         symbols_to_logits_fn: Interface to the model, to provide logits.
-        Shoud take [batch_size, decoded_ids] and return
-        [batch_size, vocab_size]
+            Should take [batch_size, decoded_ids] and return
+            [batch_size, vocab_size]
         initial_ids: Ids to start off the decoding, this will be the first
-        thing handed to symbols_to_logits_fn (after expanding to beam size)
+            thing handed to symbols_to_logits_fn (after expanding to beam size)
             [batch_size]
         beam_size: Size of the beam.
         decode_length: Number of steps to decode for.
         vocab_size: Size of the vocab, must equal the size of the logits
-        returned by symbols_to_logits_fn
+            returned by symbols_to_logits_fn
         alpha: alpha for length penalty.
         states: dict (possibly nested) of decoding states.
         eos_id: ID for end of sentence.
         stop_early: a boolean - stop once best sequence is provably
-        determined.
+            determined.
+
     Returns:
         Tuple of
         (decoded beams [batch_size, beam_size, decode_length]
@@ -282,12 +285,13 @@ def grow_finished(finished_seq, finished_scores, finished_flags,
         finished_flags: finished bools for each of these sequences.
             [batch_size, beam_size]
         curr_seq: current topk sequence that has been grown by one
-        position.
+            position.
             [batch_size, beam_size, current_decoded_length]
         curr_scores: scores for each of these sequences. [batch_size,
-        beam_size]
+            beam_size]
         curr_finished: Finished flags for each of these sequences.
             [batch_size, beam_size]
+
     Returns:
         Tuple of
         (Topk sequences based on scores,
@@ -321,7 +325,7 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,

     Args:
         curr_seq: current topk sequence that has been grown by one
-        position.
+            position.
             [batch_size, beam_size, i+1]
         curr_scores: scores for each of these sequences.
             [batch_size, beam_size]
@@ -330,6 +334,7 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
         curr_finished: Finished flags for each of these sequences.
             [batch_size, beam_size]
         states: dict (possibly nested) of decoding states.
+
     Returns:
         Tuple of
         (Topk sequences based on scores,
@@ -363,6 +368,7 @@ def grow_topk(i, alive_seq, alive_log_probs, states):
         alive_log_probs: probabilities of these sequences.
             [batch_size, beam_size]
         states: dict (possibly nested) of decoding states.
+
     Returns:
         Tuple of
         (Topk sequences extended by the next word,
@@ -521,7 +527,7 @@ def _is_finished(i, unused_alive_seq, alive_log_probs,
         finished_scores: scores for each of these sequences.
             [batch_size, beam_size]
         finished_in_finished: finished bools for each of these
-        sequences. [batch_size, beam_size]
+            sequences. [batch_size, beam_size]

     Returns:
         Bool.
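For reference, the call pattern documented in the updated `beam_search` docstring looks roughly as follows. This is a minimal sketch assuming TensorFlow 1.x graph mode; the toy `symbols_to_logits_fn` (an embedding plus a single projection matrix), the vocabulary/beam sizes, and `alpha=0.6` are illustrative choices, not values taken from this diff.

```python
# Minimal usage sketch of the `beam_search` interface documented above,
# assuming TensorFlow 1.x graph mode. The toy "model", vocabulary size,
# and hyperparameter values are illustrative, not part of this change.
import tensorflow as tf
from texar.utils.beam_search import beam_search

batch_size, beam_size, vocab_size, hidden_dim = 2, 4, 10, 32
decode_length, eos_id = 5, 1

# Create variables outside the decoding loop.
embedding = tf.get_variable("embedding", [vocab_size, hidden_dim])
softmax_w = tf.get_variable("softmax_w", [hidden_dim, vocab_size])

def symbols_to_logits_fn(ids):
    # `ids`: [batch_size * beam_size, decoded_length] after beam expansion;
    # return next-step logits of shape [batch_size * beam_size, vocab_size].
    last = tf.nn.embedding_lookup(embedding, ids[:, -1])
    return tf.matmul(last, softmax_w)

initial_ids = tf.zeros([batch_size], dtype=tf.int32)  # e.g. <BOS> id 0

# Returns (decoded beams, length-penalized scores) per the docstring above.
seqs, scores = beam_search(
    symbols_to_logits_fn,
    initial_ids,
    beam_size,
    decode_length,
    vocab_size,
    alpha=0.6,          # length-penalty strength
    eos_id=eos_id,
    states=None,
    stop_early=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    beams, beam_scores = sess.run([seqs, scores])
```

A larger `alpha` favors longer hypotheses when the length-penalized scores are compared, which is why the docstring calls it out separately from the raw log probabilities.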