From 832904b890bd57609df976ac57f656effdfd18d6 Mon Sep 17 00:00:00 2001 From: Sean Craven Date: Mon, 6 May 2024 20:11:12 +0100 Subject: [PATCH] fix: Fix tensor shape error, during llava inference. --- llava/mm_utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/llava/mm_utils.py b/llava/mm_utils.py index dc8ba56d..5c6627d8 100644 --- a/llava/mm_utils.py +++ b/llava/mm_utils.py @@ -269,7 +269,18 @@ def call_for_batch( keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids ] for keyword_id in self.keyword_ids: - if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all(): + if keyword_id.ndim == 3: + if (output_ids[0, -keyword_id.shape[0] :, None] == keyword_id).all(): + return True + elif keyword_id.ndim == 2: + if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all(): + return True + else: + raise ValueError( + "Keyword tensor should have 2 or 3 dimensions, got {}".format( + keyword_id.ndim + ) + ) return True outputs = self.tokenizer.batch_decode( output_ids[:, -offset:], skip_special_tokens=True