diff --git a/src/eva/language/models/wrappers/base.py b/src/eva/language/models/wrappers/base.py index 548586255..bb9580b1c 100644 --- a/src/eva/language/models/wrappers/base.py +++ b/src/eva/language/models/wrappers/base.py @@ -21,19 +21,31 @@ class LanguageModel(base.BaseModel[TextBatch, List[str]]): expected by the `model_forward` method. """ + _default_system_prompt = ( + "You are a helpful language assistant. " + "You are able to process the text content that the user provides," + "and assist the user with a variety of tasks using natural language." + ) + def __init__( - self, system_prompt: str | None, output_transforms: Callable | None = None + self, system_prompt: str | None = None, output_transforms: Callable | None = None ) -> None: """Creates a new model instance. Args: - system_prompt: The system prompt to use for the model (optional). + system_prompt: The system prompt to use for the model. If set to None, + will use the default system prompt. If you don't want to use any + system prompt, you can set this to an empty string. output_transforms: Optional transforms to apply to the output of the model's forward pass. """ super().__init__(transforms=output_transforms) - self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None + self.system_message = ( + ModelSystemMessage(content=system_prompt or self._default_system_prompt) + if system_prompt != "" + else None + ) @override def forward(self, batch: TextBatch) -> List[str]: diff --git a/src/eva/multimodal/models/wrappers/base.py b/src/eva/multimodal/models/wrappers/base.py index 86bcd3d38..14e4ce617 100644 --- a/src/eva/multimodal/models/wrappers/base.py +++ b/src/eva/multimodal/models/wrappers/base.py @@ -21,19 +21,31 @@ class VisionLanguageModel(base.BaseModel[TextImageBatch, List[str]]): expected by the `model_forward` method. """ + _default_system_prompt = ( + "You are a helpful vision and language assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language." + ) + def __init__( - self, system_prompt: str | None, output_transforms: Callable | None = None + self, system_prompt: str | None = None, output_transforms: Callable | None = None ) -> None: """Creates a new model instance. Args: - system_prompt: The system prompt to use for the model (optional). + system_prompt: The system prompt to use for the model. If set to None, + will use the default system prompt. If you don't want to use any + system prompt, you can set this to an empty string. output_transforms: Optional transforms to apply to the output of the model's forward pass. """ super().__init__(transforms=output_transforms) - self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None + self.system_message = ( + ModelSystemMessage(content=system_prompt or self._default_system_prompt) + if system_prompt != "" + else None + ) @override def forward(self, batch: TextImageBatch) -> List[str]: