From 9d139001db62c8a326642eb2c8cfb7d1a4240f09 Mon Sep 17 00:00:00 2001 From: MaxFeucht Date: Wed, 12 Nov 2025 10:58:08 +0100 Subject: [PATCH] Add `delimiter` as formatting option --- .../extract_answer/delimiter/__init__.py | 13 ++ .../extract_answer/delimiter/discrete.py | 68 ++++++++++ .../extract_answer/delimiter/free_form.py | 62 ++++++++++ .../postprocess/extract_answer/factory.py | 16 ++- .../language/prompts/templates/__init__.py | 24 ++-- .../prompts/templates/delimiter/__init__.py | 13 ++ .../prompts/templates/delimiter/free_form.py | 95 ++++++++++++++ .../templates/delimiter/multiple_choice.py | 117 ++++++++++++++++++ src/eva/language/prompts/templates/factory.py | 16 ++- src/eva/language/utils/text/delimiter.py | 80 ++++++++++++ .../multimodal/models/wrappers/huggingface.py | 2 +- 11 files changed, 489 insertions(+), 17 deletions(-) create mode 100644 src/eva/language/models/postprocess/extract_answer/delimiter/__init__.py create mode 100644 src/eva/language/models/postprocess/extract_answer/delimiter/discrete.py create mode 100644 src/eva/language/models/postprocess/extract_answer/delimiter/free_form.py create mode 100644 src/eva/language/prompts/templates/delimiter/__init__.py create mode 100644 src/eva/language/prompts/templates/delimiter/free_form.py create mode 100644 src/eva/language/prompts/templates/delimiter/multiple_choice.py create mode 100644 src/eva/language/utils/text/delimiter.py diff --git a/src/eva/language/models/postprocess/extract_answer/delimiter/__init__.py b/src/eva/language/models/postprocess/extract_answer/delimiter/__init__.py new file mode 100644 index 000000000..7d2337d6f --- /dev/null +++ b/src/eva/language/models/postprocess/extract_answer/delimiter/__init__.py @@ -0,0 +1,13 @@ +"""Delimiter-based answer extraction classes.""" + +from eva.language.models.postprocess.extract_answer.delimiter.discrete import ( + ExtractDiscreteAnswerFromDelimiter, +) +from eva.language.models.postprocess.extract_answer.delimiter.free_form 
import ( + ExtractAnswerFromDelimiter, +) + +__all__ = [ + "ExtractDiscreteAnswerFromDelimiter", + "ExtractAnswerFromDelimiter", +] diff --git a/src/eva/language/models/postprocess/extract_answer/delimiter/discrete.py b/src/eva/language/models/postprocess/extract_answer/delimiter/discrete.py new file mode 100644 index 000000000..98e236c98 --- /dev/null +++ b/src/eva/language/models/postprocess/extract_answer/delimiter/discrete.py @@ -0,0 +1,68 @@ +"""Postprocessing transforms for extracting discrete answers from delimiter-based responses.""" + +from typing import Dict + +from typing_extensions import override + +from eva.language.models.postprocess.extract_answer.base import ( + ExtractDiscreteAnswerFromStructuredOutput, +) +from eva.language.utils.text import delimiter as delimiter_utils + + +class ExtractDiscreteAnswerFromDelimiter(ExtractDiscreteAnswerFromStructuredOutput): + """Extracts discrete answers from delimiter-based text responses to int tensors.""" + + def __init__( + self, + mapping: Dict[str, int], + answer_key: str = "answer", + case_sensitive: bool = False, + raise_if_missing: bool = True, + missing_answer: int = -1, + missing_limit: int = 5, + delimiter: str = "####", + ) -> None: + """Initialize the transform. + + Args: + mapping: Mapping from answer strings to integer IDs. + answer_key: The key within the structured object that stores the answer. + case_sensitive: Whether to treat mappings as case sensitive. + raise_if_missing: Whether to raise an error if an answer is missing + or not found in the mapping. If False, will return `missing_answer` + instead. + missing_answer: The integer value to return if the answer is missing + and `raise_if_missing` is False or the number of missing answers + are still below `missing_limit`. + missing_limit: The maximum number of missing responses before raising + an error, if `raise_if_missing` is True. + delimiter: The delimiter string to search for (default "####"). 
+ """ + super().__init__( + mapping=mapping, + answer_key=answer_key, + case_sensitive=case_sensitive, + raise_if_missing=raise_if_missing, + missing_answer=missing_answer, + missing_limit=missing_limit, + ) + self.delimiter = delimiter + + @override + def _extract_structured_data(self, value: str) -> Dict[str, str] | None: + """Extract delimiter-based data from a string. + + Args: + value: The input string containing delimiter text. + + Returns: + Dict[str, str] | None: The extracted delimiter object or None if extraction failed. + """ + return delimiter_utils.extract_delimiter( + text=value, + delimiter=self.delimiter, + answer_options=list(self.mapping.keys()), + answer_key=self.answer_key, + case_sensitive=self.case_sensitive, + ) diff --git a/src/eva/language/models/postprocess/extract_answer/delimiter/free_form.py b/src/eva/language/models/postprocess/extract_answer/delimiter/free_form.py new file mode 100644 index 000000000..50c84bd1a --- /dev/null +++ b/src/eva/language/models/postprocess/extract_answer/delimiter/free_form.py @@ -0,0 +1,62 @@ +"""Postprocessing transforms for extracting free-form answers from delimiter-based responses.""" + +from typing import Dict + +from typing_extensions import override + +from eva.language.models.postprocess.extract_answer.base import ExtractAnswerFromStructuredOutput +from eva.language.utils.text import delimiter as delimiter_utils + + +class ExtractAnswerFromDelimiter(ExtractAnswerFromStructuredOutput): + """Extracts free-form answers from delimiter-based text responses.""" + + def __init__( + self, + answer_key: str = "answer", + case_sensitive: bool = False, + raise_if_missing: bool = True, + missing_answer: int = -1, + missing_limit: int = 5, + delimiter: str = "####", + ) -> None: + """Initialize the transform. + + Args: + answer_key: The key within the structured object that stores the answer. + case_sensitive: Whether to treat answers as case sensitive. 
+ raise_if_missing: Whether to raise an error if an answer is missing. + If False, will return `missing_answer` instead. + missing_answer: The integer value to return if the answer is missing + and `raise_if_missing` is False or the number of missing answers + are still below `missing_limit`. + missing_limit: The maximum number of missing responses before raising + an error, if `raise_if_missing` is True. + delimiter: The delimiter string to search for (default "####"). + """ + super().__init__( + answer_key=answer_key, + case_sensitive=case_sensitive, + raise_if_missing=raise_if_missing, + missing_answer=missing_answer, + missing_limit=missing_limit, + ) + self.delimiter = delimiter + + @override + def _extract_structured_data(self, value: str) -> Dict[str, str] | None: + """Extract delimiter-based data from a string. + + Args: + value: The input string containing delimiter text. + + Returns: + Dict[str, str] | None: The extracted delimiter object or None if extraction failed. + """ + return delimiter_utils.extract_delimiter( + text=value, + delimiter=self.delimiter, + answer_key=self.answer_key, + case_sensitive=self.case_sensitive, + answer_options=None, # No validation for free-form answers + ) diff --git a/src/eva/language/models/postprocess/extract_answer/factory.py b/src/eva/language/models/postprocess/extract_answer/factory.py index c3d3b3ef1..da446cbaf 100644 --- a/src/eva/language/models/postprocess/extract_answer/factory.py +++ b/src/eva/language/models/postprocess/extract_answer/factory.py @@ -3,6 +3,10 @@ from typing import Literal from eva.language.models.postprocess.extract_answer.base import ExtractAnswerFromStructuredOutput +from eva.language.models.postprocess.extract_answer.delimiter import ( + ExtractAnswerFromDelimiter, + ExtractDiscreteAnswerFromDelimiter, +) from eva.language.models.postprocess.extract_answer.json import ( ExtractAnswerFromJson, ExtractDiscreteAnswerFromJson, @@ -21,12 +25,12 @@ class ExtractDiscreteAnswer: """Factory for 
creating discrete answer extractors.""" def __new__( - cls, answer_format: Literal["json", "xml", "raw"], extract_kwargs: dict + cls, answer_format: Literal["json", "xml", "raw", "delimiter"], extract_kwargs: dict ) -> ExtractAnswerFromStructuredOutput: """Create a discrete answer extractor based on the answer format. Args: - answer_format: The format of the answer to extract ('json', 'xml', or 'raw'). + answer_format: The format of the answer to extract ('json', 'xml', 'raw', 'delimiter'). extract_kwargs: Keyword arguments passed to the extractor constructor. Returns: @@ -39,6 +43,8 @@ def __new__( return ExtractDiscreteAnswerFromXml(**extract_kwargs) case "raw": return ExtractDiscreteAnswerFromRaw(**extract_kwargs) + case "delimiter": + return ExtractDiscreteAnswerFromDelimiter(**extract_kwargs) case _: raise ValueError(f"Unknown answer format: {answer_format}") @@ -47,12 +53,12 @@ class ExtractAnswer: """Factory for creating answer extractors.""" def __new__( - cls, answer_format: Literal["json", "xml"], extract_kwargs: dict + cls, answer_format: Literal["json", "xml", "raw", "delimiter"], extract_kwargs: dict ) -> ExtractAnswerFromStructuredOutput: """Create an answer extractor based on the answer format. Args: - answer_format: The format of the answer to extract ('json' or 'xml'). + answer_format: The format of the answer to extract ('json', 'xml', 'raw', 'delimiter'). extract_kwargs: Keyword arguments passed to the extractor constructor. 
Returns: @@ -65,5 +71,7 @@ def __new__( return ExtractAnswerFromXml(**extract_kwargs) case "raw": return ExtractAnswerFromRaw(**extract_kwargs) + case "delimiter": + return ExtractAnswerFromDelimiter(**extract_kwargs) case _: raise ValueError(f"Unknown answer format: {answer_format}") diff --git a/src/eva/language/prompts/templates/__init__.py b/src/eva/language/prompts/templates/__init__.py index 23b1b5563..7ed1e9645 100644 --- a/src/eva/language/prompts/templates/__init__.py +++ b/src/eva/language/prompts/templates/__init__.py @@ -1,6 +1,10 @@ """Prompt templating API.""" from eva.language.prompts.templates.base import PromptTemplate +from eva.language.prompts.templates.delimiter import ( + DelimiterFreeFormQuestionPromptTemplate, + DelimiterMultipleChoicePromptTemplate, +) from eva.language.prompts.templates.factory import ( FreeFormQuestionPromptTemplate, MultipleChoicePromptTemplate, @@ -13,17 +17,21 @@ RawFreeFormQuestionPromptTemplate, RawMultipleChoicePromptTemplate, ) -from eva.language.prompts.templates.xml import XmlMultipleChoicePromptTemplate +from eva.language.prompts.templates.xml import ( + XmlFreeFormQuestionPromptTemplate, + XmlMultipleChoicePromptTemplate, +) __all__ = [ - "JsonMultipleChoicePromptTemplate", - "RawMultipleChoicePromptTemplate", - "XmlMultipleChoicePromptTemplate", - "JsonFreeFormQuestionPromptTemplate", - "RawFreeFormQuestionPromptTemplate", + "PromptTemplate", "FreeFormQuestionPromptTemplate", "MultipleChoicePromptTemplate", - "PromptTemplate", + "JsonFreeFormQuestionPromptTemplate", "JsonMultipleChoicePromptTemplate", - "FreeFormQuestionPromptTemplate", + "RawFreeFormQuestionPromptTemplate", + "RawMultipleChoicePromptTemplate", + "XmlFreeFormQuestionPromptTemplate", + "XmlMultipleChoicePromptTemplate", + "DelimiterFreeFormQuestionPromptTemplate", + "DelimiterMultipleChoicePromptTemplate", ] diff --git a/src/eva/language/prompts/templates/delimiter/__init__.py b/src/eva/language/prompts/templates/delimiter/__init__.py new file 
mode 100644 index 000000000..245419733 --- /dev/null +++ b/src/eva/language/prompts/templates/delimiter/__init__.py @@ -0,0 +1,13 @@ +"""Delimiter-based prompt templates.""" + +from eva.language.prompts.templates.delimiter.free_form import ( + DelimiterFreeFormQuestionPromptTemplate, +) +from eva.language.prompts.templates.delimiter.multiple_choice import ( + DelimiterMultipleChoicePromptTemplate, +) + +__all__ = [ + "DelimiterFreeFormQuestionPromptTemplate", + "DelimiterMultipleChoicePromptTemplate", +] diff --git a/src/eva/language/prompts/templates/delimiter/free_form.py b/src/eva/language/prompts/templates/delimiter/free_form.py new file mode 100644 index 000000000..93373366d --- /dev/null +++ b/src/eva/language/prompts/templates/delimiter/free_form.py @@ -0,0 +1,95 @@ +"""Prompt templates for free-form questions with delimiter-based output format.""" + +# ruff: noqa: E501 + +from __future__ import annotations + +import textwrap +from typing import Sequence + +from jinja2 import Template +from typing_extensions import override + +from eva.language.prompts.templates import base, typings +from eva.language.utils.text import format as format_utils + + +class DelimiterFreeFormQuestionPromptTemplate(base.PromptTemplate): + """Prompt template for free-form questions with answers after #### delimiter.""" + + template: str = textwrap.dedent( + """\ + {{ preamble }} + + {% if examples %} + Below are some examples of how to answer questions: + + {% for ex in examples %} + Example {{ loop.index }}: + Question: {{ ex.question }} + {% if ex.context %} + Context: {{ ex.context }} + {% endif %} + Answer: {{ ex.answer }} + --- + {% endfor %} + Now please answer the following question. + {% endif %} + + Question: {{ question }} + {% if context %} + Context: + {{ context }} + {% endif %} + + IMPORTANT: Think step-by-step before giving your final answer, then provide your final answer after "#### ". 
+ + {% if not examples %} + Example Answer: + Your explanation and reasoning can go here... + #### [your final answer here] + {% endif %} + """ + ) + """Base template to be rendered via Jinja2.""" + + def __init__( + self, + ) -> None: + """Initializes the prompt template.""" + super().__init__() + + @override + def render( + self, + *, + question: str, + context: str | Sequence[str] | None = None, + examples: Sequence[typings.QuestionAnswerExample] | None = None, + preamble: str | None = None, + ) -> str: + """Render the template with provided values. + + Args: + question: The question to ask the model. + context: Supporting context text(s) for the question. + examples: A sequence of question & answer pairs to include as examples. + Expected format is a list of dicts with 'question', 'answer', and + optional 'context' keys. + preamble: Optional preamble text to include at the top of the prompt. + + Returns: + The rendered prompt string. + """ + if not isinstance(question, str) or not question.strip(): + raise ValueError("`question` must be a non-empty string.") + + jinja_template = Template(self.template) + rendered = jinja_template.render( + question=question.strip(), + context=format_utils.format_list_items(context) if context else None, + examples=examples, + preamble=(preamble or "").strip(), + ) + + return format_utils.remove_multi_blank_lines(textwrap.dedent(rendered).strip() + "\n") diff --git a/src/eva/language/prompts/templates/delimiter/multiple_choice.py b/src/eva/language/prompts/templates/delimiter/multiple_choice.py new file mode 100644 index 000000000..b8845e621 --- /dev/null +++ b/src/eva/language/prompts/templates/delimiter/multiple_choice.py @@ -0,0 +1,117 @@ +"""Prompt templates for multiple choice questions with delimiter-based output format.""" + +# ruff: noqa: E501 + +from __future__ import annotations + +import string +import textwrap +from typing import Sequence + +from jinja2 import Template +from typing_extensions import override + +from 
eva.language.prompts.templates import base, typings +from eva.language.utils.text import format as format_utils + + +class DelimiterMultipleChoicePromptTemplate(base.PromptTemplate): + """Prompt template for multiple choice questions with answers after #### delimiter.""" + + template: str = textwrap.dedent( + """\ + {{ preamble }} + + {% if examples %} + Below are some examples of how to answer questions: + + {% for ex in examples %} + Example {{ loop.index }}: + Question: {{ ex.question }} + Answer: {{ ex.answer }} + --- + {% endfor %} + Now please answer the following question. + {% endif %} + + {{ question }} + {% if context %} + Context: + {{ context }} + {% endif %} + + IMPORTANT: Think step-by-step before giving your final answer, then output the final answer after "#### ". + {% if use_option_letters -%} + The answer must be the letter (e.g., "A", "B", "C", ...) + corresponding to your chosen option from the list below: + {%- else -%} + The answer must exactly match one of the options listed below: + {%- endif %} + {{ answer_options }} + + {% if not examples %} + Example Answer: + Your explanation for why you chose this answer can go here... + #### {{ example_answer }} + {% endif %} + """ + ) + """Base template to be rendered via Jinja2.""" + + def __init__( + self, + ) -> None: + """Initializes the prompt template.""" + super().__init__() + + @override + def render( + self, + *, + question: str, + context: str | Sequence[str] | None, + answer_options: Sequence[str], + examples: Sequence[typings.QuestionAnswerExample] | None = None, + example_answer: str | None = None, + preamble: str | None = None, + use_option_letters: bool | None = None, + enable_cot: bool | None = None, + ) -> str: + """Render the template with provided values. + + Args: + question: The question to ask the model. + context: Supporting context text(s) for the question. + answer_options: Allowed answer options. + examples: A sequence of question & answer pairs to include as examples. 
+ Expected format is a list of dicts with 'question', 'answer', and + optional 'context' keys. + example_answer: Optional example answer for the delimiter snippet. Defaults to first option. + preamble: Optional preamble text to include at the top of the prompt. + use_option_letters: Whether to prefix options with letters (A, B, C, ...). + enable_cot: This parameter is ignored for delimiter templates as CoT is always enabled. + + Returns: + The rendered prompt string. + """ + if not isinstance(question, str) or not question.strip(): + raise ValueError("`question` must be a non-empty string.") + + jinja_template = Template(self.template) + rendered = jinja_template.render( + question=question.strip(), + context=format_utils.format_list_items(context) if context else None, + answer_options=format_utils.format_list_items( + answer_options, style="letters" if use_option_letters else "bullets" + ), + examples=examples, + example_answer=( + example_answer + if isinstance(example_answer, str) + else (string.ascii_uppercase[0] if use_option_letters else answer_options[0]) + ).strip(), + preamble=(preamble or "").strip(), + use_option_letters=use_option_letters, + ) + + return format_utils.remove_multi_blank_lines(textwrap.dedent(rendered).strip() + "\n") diff --git a/src/eva/language/prompts/templates/factory.py b/src/eva/language/prompts/templates/factory.py index 4cbb8fc9b..379d72e23 100644 --- a/src/eva/language/prompts/templates/factory.py +++ b/src/eva/language/prompts/templates/factory.py @@ -5,6 +5,10 @@ from typing_extensions import override from eva.language.prompts.templates.base import PromptTemplate +from eva.language.prompts.templates.delimiter import ( + DelimiterFreeFormQuestionPromptTemplate, + DelimiterMultipleChoicePromptTemplate, +) from eva.language.prompts.templates.json import ( JsonFreeFormQuestionPromptTemplate, JsonMultipleChoicePromptTemplate, @@ -23,12 +27,12 @@ class FreeFormQuestionPromptTemplate(PromptTemplate): """Factory for free-form 
question prompt templates based on answer format.""" def __new__( - cls, answer_format: Literal["json", "xml", "raw"], **template_kwargs + cls, answer_format: Literal["json", "xml", "raw", "delimiter"], **template_kwargs ) -> PromptTemplate: """Create a free-form question prompt template based on the answer format. Args: - answer_format: The format to use for answers ('json', 'xml', or 'raw'). + answer_format: The format to use for answers ('json', 'xml', 'raw', or 'delimiter'). **template_kwargs: Keyword arguments passed to the template constructor. Returns: @@ -41,6 +45,8 @@ def __new__( return XmlFreeFormQuestionPromptTemplate(**template_kwargs) case "raw": return RawFreeFormQuestionPromptTemplate(**template_kwargs) + case "delimiter": + return DelimiterFreeFormQuestionPromptTemplate(**template_kwargs) case _: raise ValueError(f"Unknown answer format: {answer_format}") @@ -53,12 +59,12 @@ class MultipleChoicePromptTemplate(PromptTemplate): """Factory for Multiple Choice QA prompt templates based on answer format.""" def __new__( - cls, answer_format: Literal["json", "xml", "raw"], **template_kwargs + cls, answer_format: Literal["json", "xml", "raw", "delimiter"], **template_kwargs ) -> PromptTemplate: """Create a multiple-choice prompt template based on the answer format. Args: - answer_format: The format to use for answers ('json', 'xml', or 'raw'). + answer_format: The format to use for answers ('json', 'xml', 'raw', or 'delimiter'). **template_kwargs: Keyword arguments passed to the template constructor. 
Returns: @@ -71,6 +77,8 @@ def __new__( return XmlMultipleChoicePromptTemplate(**template_kwargs) case "raw": return RawMultipleChoicePromptTemplate(**template_kwargs) + case "delimiter": + return DelimiterMultipleChoicePromptTemplate(**template_kwargs) case _: raise ValueError(f"Unknown answer format: {answer_format}") diff --git a/src/eva/language/utils/text/delimiter.py b/src/eva/language/utils/text/delimiter.py new file mode 100644 index 000000000..f07d0ed5a --- /dev/null +++ b/src/eva/language/utils/text/delimiter.py @@ -0,0 +1,80 @@ +"""Delimiter-based text extraction utilities.""" + +from typing import List + +from loguru import logger + +from eva.language.utils.text.raw import _extract_answer_from_options + + +def extract_delimiter( + text: str, + delimiter: str = "####", + answer_key: str = "answer", + case_sensitive: bool = False, + answer_options: List[str] | None = None, +) -> dict | None: + """Extract answer from text after a delimiter marker. + + Extracts the content that appears after the specified delimiter (default "####"). + For multiple-choice questions, validates the extracted answer against provided options. + Returns the last occurrence if multiple delimiters are found. + + Args: + text: The input string containing the delimiter and answer. + delimiter: The delimiter string to search for (default "####"). + answer_key: The key to use in the returned dictionary. + case_sensitive: Whether to treat matching as case sensitive. + answer_options: Optional list of valid answer options for validation. + + Returns: + Dictionary with answer_key mapped to extracted answer, or None if no delimiter found. + + Examples: + >>> extract_delimiter("Let me think... 
#### B") + {'answer': 'B'} + + >>> extract_delimiter("Analysis: #### The answer is 42", answer_key="result") + {'result': 'The answer is 42'} + + >>> extract_delimiter("First #### A, but actually #### B") + {'answer': 'B'} + + >>> extract_delimiter("No delimiter here") is None + True + """ + if not text or not isinstance(text, str): + return None + + try: + parts = text.split(delimiter) + + if len(parts) < 2: + logger.debug(f"No delimiter '{delimiter}' found in text") + return None + + answer_text = parts[-1].strip() + + if not answer_text: + logger.warning(f"Empty answer after delimiter '{delimiter}'") + return None + + if answer_options: + extracted_answer = _extract_answer_from_options( + answer_text, answer_options, case_sensitive + ) + if extracted_answer: + result = extracted_answer if case_sensitive else extracted_answer.upper() + return {answer_key: result} + + logger.warning( + f"Could not match delimited answer '{answer_text}' " + f"to any of the provided options: {answer_options}" + ) + return None + + return {answer_key: answer_text} + + except Exception as e: + logger.warning(f"Failed to extract answer from delimiter text: {e}") + return None diff --git a/src/eva/multimodal/models/wrappers/huggingface.py b/src/eva/multimodal/models/wrappers/huggingface.py index 554d898c2..85ad848b9 100644 --- a/src/eva/multimodal/models/wrappers/huggingface.py +++ b/src/eva/multimodal/models/wrappers/huggingface.py @@ -189,7 +189,7 @@ def _decode_output(self, output: torch.Tensor, instruction_length: int) -> List[ output[:, :instruction_length], skip_special_tokens=False ) decoded_output = self.processor.batch_decode( # type: ignore - output[:, instruction_length:], skip_special_tokens=False + output[:, instruction_length:], skip_special_tokens=True ) logger.debug(f"Decoded input: {decoded_input}")