Skip to content

Commit 092bec9

Browse files
authored
Use unpack_batch function across multimodal wrappers (#886)
* Move `unpack_batch` to `multimodal/utils` to use for `LiteLLMModel`
* lint
1 parent 8f55f8d commit 092bec9

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed

src/eva/multimodal/models/wrappers/huggingface.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from eva.language.utils.text import messages as language_message_utils
1414
from eva.multimodal.models.typings import TextImageBatch
1515
from eva.multimodal.models.wrappers import base
16+
from eva.multimodal.utils.batch import unpack_batch
1617
from eva.multimodal.utils.text import messages as message_utils
1718

1819

@@ -72,7 +73,7 @@ def format_inputs(self, batch: TextImageBatch | TextBatch) -> Dict[str, torch.Te
7273
"pixel_values": ...
7374
}
7475
"""
75-
message_batch, image_batch, _, _ = self._unpack_batch(batch)
76+
message_batch, image_batch, _, _ = unpack_batch(batch)
7677
with_images = image_batch is not None
7778

7879
message_batch = language_message_utils.batch_insert_system_message(
@@ -158,11 +159,6 @@ def load_processor(self) -> Callable:
158159
**self.processor_kwargs,
159160
)
160161

161-
def _unpack_batch(self, batch: TextImageBatch | TextBatch) -> tuple:
162-
if isinstance(batch, TextImageBatch):
163-
return batch.text, batch.image, batch.target, batch.metadata
164-
return batch.text, None, batch.target, batch.metadata
165-
166162
def _decode_output(self, output: torch.Tensor, instruction_length: int) -> List[str]:
167163
"""Decode the model's batch output to text.
168164

src/eva/multimodal/models/wrappers/litellm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from eva.language.utils.text import messages as language_message_utils
1111
from eva.multimodal.models.typings import TextImageBatch
1212
from eva.multimodal.models.wrappers import base
13+
from eva.multimodal.utils.batch import unpack_batch
1314
from eva.multimodal.utils.text import messages as message_utils
1415

1516

@@ -43,7 +44,7 @@ def __init__(
4344

4445
@override
4546
def format_inputs(self, batch: TextImageBatch) -> List[List[Dict[str, Any]]]:
46-
message_batch, image_batch, _, _ = TextImageBatch(*batch)
47+
message_batch, image_batch, _, _ = unpack_batch(batch)
4748

4849
message_batch = language_message_utils.batch_insert_system_message(
4950
message_batch, self.system_message
src/eva/multimodal/utils/batch/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Multimodal batch utilities API."""
2+
3+
from eva.multimodal.utils.batch.unpack import unpack_batch
4+
5+
__all__ = ["unpack_batch"]
src/eva/multimodal/utils/batch/unpack.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Unpack batch utility function."""
2+
3+
from eva.language.models.typings import TextBatch
4+
from eva.multimodal.models.typings import TextImageBatch
5+
6+
7+
def unpack_batch(batch: TextImageBatch | TextBatch) -> tuple:
    """Split a batch into its ``(text, image, target, metadata)`` components.

    Args:
        batch: Either a ``TextImageBatch`` or a ``TextBatch``.

    Returns:
        A 4-tuple ``(text, image, target, metadata)``. The image slot is
        ``batch.image`` when ``batch`` is a ``TextImageBatch`` and ``None``
        otherwise, so callers can unpack both batch kinds the same way.
    """
    text = batch.text
    # Only TextImageBatch is read for an image; any other batch gets None.
    image = batch.image if isinstance(batch, TextImageBatch) else None
    return text, image, batch.target, batch.metadata

0 commit comments

Comments
 (0)