Skip to content

Commit

Permalink
Merge pull request #40 from SeanCraven314/minor_fix
Browse files Browse the repository at this point in the history
fix: Fix tensor shape error, during llava inference.
  • Loading branch information
Efficient-Large-Language-Model authored May 7, 2024
2 parents 1b0caf4 + 832904b commit f85297f
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion llava/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,18 @@ def call_for_batch(
keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
]
for keyword_id in self.keyword_ids:
if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
if keyword_id.ndim == 3:
if (output_ids[0, -keyword_id.shape[0] :, None] == keyword_id).all():
return True
elif keyword_id.ndim == 2:
if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
return True
else:
raise ValueError(
"Keyword tensor should have 2 or 3 dimensions, got {}".format(
keyword_id.ndim
)
)
return True
outputs = self.tokenizer.batch_decode(
output_ids[:, -offset:], skip_special_tokens=True
Expand Down

0 comments on commit f85297f

Please sign in to comment.