From 4132aceb467ba17a3c1862537ce280e16a55ca91 Mon Sep 17 00:00:00 2001 From: "Zhang, Yanli L" Date: Fri, 15 Nov 2024 18:30:28 -0800 Subject: [PATCH] Refactor resolve_beam to fix recursion depth issue Signed-off-by: Zhang, Yanli L --- optimum/habana/transformers/generation/utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index f088974f3..8f220d978 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -2764,12 +2764,17 @@ def finalize_beams(initial_ids, beam_trace, model_config, length_penalty): root = (float("-inf"), None, None, False) def resolve_beam(beam): - if beam == root: - return [] - score, prev, tok, is_finished = beam - rest = resolve_beam(prev) - rest.append(tok) - return rest + tokens = [] + current = beam + + while current != root: + score, prev, tok, is_finished = current + tokens.append(tok) + current = prev + + # Reverse tokens since we are adding them in reverse order + tokens.reverse() + return tokens prev_beams = [[root] * num_beams] * bs best = [root] * bs