Skip to content

Commit

Permalink
Refactor resolve_beam to fix recursion depth issue
Browse files Browse the repository at this point in the history
Signed-off-by: Zhang, Yanli L <[email protected]>
  • Loading branch information
Yanli2190 committed Nov 16, 2024
1 parent be631a4 commit 4132ace
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2764,12 +2764,17 @@ def finalize_beams(initial_ids, beam_trace, model_config, length_penalty):
root = (float("-inf"), None, None, False)

def resolve_beam(beam):
if beam == root:
return []
score, prev, tok, is_finished = beam
rest = resolve_beam(prev)
rest.append(tok)
return rest
tokens = []
current = beam

while current != root:
score, prev, tok, is_finished = current
tokens.append(tok)
current = prev

# Reverse tokens since we are adding them in reverse order
tokens.reverse()
return tokens

prev_beams = [[root] * num_beams] * bs
best = [root] * bs
Expand Down

0 comments on commit 4132ace

Please sign in to comment.