Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/eva/multimodal/models/modules/vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from eva.core.models.modules import module
from eva.core.models.modules.utils import batch_postprocess
from eva.language.models.typings import ModelOutput
from eva.multimodal.models import wrappers
from eva.multimodal.models.typings import TextImageBatch


Expand Down Expand Up @@ -54,3 +55,11 @@ def _batch_step(self, batch: TextImageBatch) -> STEP_OUTPUT:
"targets": targets,
"metadata": metadata,
} | output

@override
def configure_model(self) -> None:
model = (
self.model.model if isinstance(self.model, wrappers.ModelFromRegistry) else self.model
)
if hasattr(model, "configure_model"):
model.configure_model() # type: ignore
14 changes: 13 additions & 1 deletion src/eva/multimodal/models/wrappers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
generation_kwargs: Dict[str, Any] | None = None,
image_key: str = "image",
image_position: Literal["before_text", "after_text"] = "after_text",
initialize_model: bool = True,
):
"""Initialize the HuggingFace model wrapper.

Expand All @@ -59,6 +60,8 @@ def __init__(
generation_kwargs: Additional generation arguments.
image_key: The key used for image inputs in the chat template.
image_position: Position of the image in the input sequence.
initialize_model: Whether to load the model in __init__ or
delay until lightning's `configure_model` hook is called.
"""
super().__init__(system_prompt=system_prompt)

Expand All @@ -68,10 +71,12 @@ def __init__(
self.processor_kwargs = processor_kwargs or {}
self.generation_kwargs = self._default_generation_kwargs | (generation_kwargs or {})
self.image_key = image_key
self.is_loaded = False
self.image_position: Literal["before_text", "after_text"] = image_position

self.processor = self.load_processor()
self.model = self.load_model()
if initialize_model:
self.model = self.load_model()

@override
def format_inputs(self, batch: TextImageBatch | TextBatch) -> Dict[str, torch.Tensor]:
Expand Down Expand Up @@ -147,6 +152,11 @@ def model_forward(self, batch: Dict[str, torch.Tensor]) -> ModelOutput:
attention_mask=batch.get("attention_mask"),
)

def configure_model(self) -> None:
"""Lightning hook to configure / load the model."""
if not self.is_loaded:
self.model = self.load_model()

@override
def load_model(self) -> nn.Module:
"""Setting up the model. Used for delayed model initialization.
Expand All @@ -166,6 +176,8 @@ def load_model(self) -> nn.Module:
if not hasattr(model, "generate"):
raise ValueError(f"Model {self.model_name_or_path} does not support generation. ")

self.is_loaded = True

return model

def load_processor(self) -> Callable:
Expand Down
Loading