Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Delimiter-based answer extraction classes."""

from eva.language.models.postprocess.extract_answer.delimiter.discrete import (
ExtractDiscreteAnswerFromDelimiter,
)
from eva.language.models.postprocess.extract_answer.delimiter.free_form import (
ExtractAnswerFromDelimiter,
)

# Public API of this package: the two extractor classes imported above are the
# only names exposed by a star-import.
__all__ = [
"ExtractDiscreteAnswerFromDelimiter",
"ExtractAnswerFromDelimiter",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Postprocessing transforms for extracting discrete answers from delimiter-based responses."""

from typing import Dict

from typing_extensions import override

from eva.language.models.postprocess.extract_answer.base import (
ExtractDiscreteAnswerFromStructuredOutput,
)
from eva.language.utils.text import delimiter as delimiter_utils


class ExtractDiscreteAnswerFromDelimiter(ExtractDiscreteAnswerFromStructuredOutput):
    """Turns delimiter-marked answers in text responses into integer IDs.

    The response text is handed to ``delimiter_utils.extract_delimiter``
    together with the configured delimiter and the keys of ``mapping`` as the
    permitted answer options. Everything else (mapping lookup, handling of
    missing answers) is inherited from the structured-output base class.
    """

    def __init__(
        self,
        mapping: Dict[str, int],
        answer_key: str = "answer",
        case_sensitive: bool = False,
        raise_if_missing: bool = True,
        missing_answer: int = -1,
        missing_limit: int = 5,
        delimiter: str = "####",
    ) -> None:
        """Configure the extractor.

        Args:
            mapping: Lookup table from answer strings to their integer IDs.
            answer_key: Key under which the answer is stored in the extracted
                structured object.
            case_sensitive: If True, answer strings are matched case sensitively.
            raise_if_missing: If True, an error is raised when an answer is
                missing or not part of the mapping. If False, `missing_answer`
                is returned instead.
            missing_answer: Integer returned for a missing answer when
                `raise_if_missing` is False, or while the count of missing
                answers is still below `missing_limit`.
            missing_limit: How many missing responses are tolerated before an
                error is raised when `raise_if_missing` is True.
            delimiter: Marker string that precedes the answer (default "####").
        """
        # All options except the delimiter are owned by the base class.
        base_options = {
            "mapping": mapping,
            "answer_key": answer_key,
            "case_sensitive": case_sensitive,
            "raise_if_missing": raise_if_missing,
            "missing_answer": missing_answer,
            "missing_limit": missing_limit,
        }
        super().__init__(**base_options)
        self.delimiter = delimiter

    @override
    def _extract_structured_data(self, value: str) -> Dict[str, str] | None:
        """Pull the delimiter-marked answer out of a response string.

        Args:
            value: Response text expected to contain the delimiter.

        Returns:
            Dict[str, str] | None: Whatever ``extract_delimiter`` produces for
                the text; per this method's contract, None signals that
                extraction failed.
        """
        # Only answers that appear as mapping keys are offered as options.
        allowed_answers = list(self.mapping.keys())
        return delimiter_utils.extract_delimiter(
            text=value,
            delimiter=self.delimiter,
            answer_options=allowed_answers,
            answer_key=self.answer_key,
            case_sensitive=self.case_sensitive,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Postprocessing transforms for extracting free-form answers from delimiter-based responses."""

from typing import Dict

from typing_extensions import override

from eva.language.models.postprocess.extract_answer.base import ExtractAnswerFromStructuredOutput
from eva.language.utils.text import delimiter as delimiter_utils


class ExtractAnswerFromDelimiter(ExtractAnswerFromStructuredOutput):
    """Pulls free-form answers out of delimiter-based text responses.

    The response text is handed to ``delimiter_utils.extract_delimiter`` with
    no list of permitted answers, so the extracted text is not validated
    against any option set here.
    """

    def __init__(
        self,
        answer_key: str = "answer",
        case_sensitive: bool = False,
        raise_if_missing: bool = True,
        missing_answer: int = -1,
        missing_limit: int = 5,
        delimiter: str = "####",
    ) -> None:
        """Configure the extractor.

        Args:
            answer_key: Key under which the answer is stored in the extracted
                structured object.
            case_sensitive: If True, answers are treated case sensitively.
            raise_if_missing: If True, an error is raised when an answer is
                missing. If False, `missing_answer` is returned instead.
            missing_answer: Integer returned for a missing answer when
                `raise_if_missing` is False, or while the count of missing
                answers is still below `missing_limit`.
            missing_limit: How many missing responses are tolerated before an
                error is raised when `raise_if_missing` is True.
            delimiter: Marker string that precedes the answer (default "####").
        """
        # All options except the delimiter are owned by the base class.
        base_options = {
            "answer_key": answer_key,
            "case_sensitive": case_sensitive,
            "raise_if_missing": raise_if_missing,
            "missing_answer": missing_answer,
            "missing_limit": missing_limit,
        }
        super().__init__(**base_options)
        self.delimiter = delimiter

    @override
    def _extract_structured_data(self, value: str) -> Dict[str, str] | None:
        """Pull the delimiter-marked answer out of a response string.

        Args:
            value: Response text expected to contain the delimiter.

        Returns:
            Dict[str, str] | None: Whatever ``extract_delimiter`` produces for
                the text; per this method's contract, None signals that
                extraction failed.
        """
        extraction_options = {
            "text": value,
            "delimiter": self.delimiter,
            "answer_key": self.answer_key,
            "case_sensitive": self.case_sensitive,
            # Free-form answers are not checked against a fixed option list.
            "answer_options": None,
        }
        return delimiter_utils.extract_delimiter(**extraction_options)
16 changes: 12 additions & 4 deletions src/eva/language/models/postprocess/extract_answer/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from typing import Literal

from eva.language.models.postprocess.extract_answer.base import ExtractAnswerFromStructuredOutput
from eva.language.models.postprocess.extract_answer.delimiter import (
ExtractAnswerFromDelimiter,
ExtractDiscreteAnswerFromDelimiter,
)
from eva.language.models.postprocess.extract_answer.json import (
ExtractAnswerFromJson,
ExtractDiscreteAnswerFromJson,
Expand All @@ -21,12 +25,12 @@ class ExtractDiscreteAnswer:
"""Factory for creating discrete answer extractors."""

def __new__(
cls, answer_format: Literal["json", "xml", "raw"], extract_kwargs: dict
cls, answer_format: Literal["json", "xml", "raw", "delimiter"], extract_kwargs: dict
) -> ExtractAnswerFromStructuredOutput:
"""Create a discrete answer extractor based on the answer format.

Args:
answer_format: The format of the answer to extract ('json', 'xml', or 'raw').
answer_format: The format of the answer to extract ('json', 'xml', 'raw', 'delimiter').
extract_kwargs: Keyword arguments passed to the extractor constructor.

Returns:
Expand All @@ -39,6 +43,8 @@ def __new__(
return ExtractDiscreteAnswerFromXml(**extract_kwargs)
case "raw":
return ExtractDiscreteAnswerFromRaw(**extract_kwargs)
case "delimiter":
return ExtractDiscreteAnswerFromDelimiter(**extract_kwargs)
case _:
raise ValueError(f"Unknown answer format: {answer_format}")

Expand All @@ -47,12 +53,12 @@ class ExtractAnswer:
"""Factory for creating answer extractors."""

def __new__(
cls, answer_format: Literal["json", "xml"], extract_kwargs: dict
cls, answer_format: Literal["json", "xml", "raw", "delimiter"], extract_kwargs: dict
) -> ExtractAnswerFromStructuredOutput:
"""Create an answer extractor based on the answer format.

Args:
answer_format: The format of the answer to extract ('json' or 'xml').
answer_format: The format of the answer to extract ('json', 'xml', 'raw', 'delimiter').
extract_kwargs: Keyword arguments passed to the extractor constructor.

Returns:
Expand All @@ -65,5 +71,7 @@ def __new__(
return ExtractAnswerFromXml(**extract_kwargs)
case "raw":
return ExtractAnswerFromRaw(**extract_kwargs)
case "delimiter":
return ExtractAnswerFromDelimiter(**extract_kwargs)
case _:
raise ValueError(f"Unknown answer format: {answer_format}")
24 changes: 16 additions & 8 deletions src/eva/language/prompts/templates/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""Prompt templating API."""

from eva.language.prompts.templates.base import PromptTemplate
from eva.language.prompts.templates.delimiter import (
DelimiterFreeFormQuestionPromptTemplate,
DelimiterMultipleChoicePromptTemplate,
)
from eva.language.prompts.templates.factory import (
FreeFormQuestionPromptTemplate,
MultipleChoicePromptTemplate,
Expand All @@ -13,17 +17,21 @@
RawFreeFormQuestionPromptTemplate,
RawMultipleChoicePromptTemplate,
)
from eva.language.prompts.templates.xml import XmlMultipleChoicePromptTemplate
from eva.language.prompts.templates.xml import (
XmlFreeFormQuestionPromptTemplate,
XmlMultipleChoicePromptTemplate,
)

__all__ = [
"JsonMultipleChoicePromptTemplate",
"RawMultipleChoicePromptTemplate",
"XmlMultipleChoicePromptTemplate",
"JsonFreeFormQuestionPromptTemplate",
"RawFreeFormQuestionPromptTemplate",
"PromptTemplate",
"FreeFormQuestionPromptTemplate",
"MultipleChoicePromptTemplate",
"PromptTemplate",
"JsonFreeFormQuestionPromptTemplate",
"JsonMultipleChoicePromptTemplate",
"FreeFormQuestionPromptTemplate",
"RawFreeFormQuestionPromptTemplate",
"RawMultipleChoicePromptTemplate",
"XmlFreeFormQuestionPromptTemplate",
"XmlMultipleChoicePromptTemplate",
"DelimiterFreeFormQuestionPromptTemplate",
"DelimiterMultipleChoicePromptTemplate",
]
13 changes: 13 additions & 0 deletions src/eva/language/prompts/templates/delimiter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Delimiter-based prompt templates."""

from eva.language.prompts.templates.delimiter.free_form import (
DelimiterFreeFormQuestionPromptTemplate,
)
from eva.language.prompts.templates.delimiter.multiple_choice import (
DelimiterMultipleChoicePromptTemplate,
)

# Public API of this package: the two delimiter prompt templates imported above
# are the only names exposed by a star-import.
__all__ = [
"DelimiterFreeFormQuestionPromptTemplate",
"DelimiterMultipleChoicePromptTemplate",
]
95 changes: 95 additions & 0 deletions src/eva/language/prompts/templates/delimiter/free_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Prompt templates for free-form questions with delimiter-based output format."""

# ruff: noqa: E501

from __future__ import annotations

import textwrap
from typing import Sequence

from jinja2 import Template
from typing_extensions import override

from eva.language.prompts.templates import base, typings
from eva.language.utils.text import format as format_utils


class DelimiterFreeFormQuestionPromptTemplate(base.PromptTemplate):
    """Free-form question prompt that asks for the final answer after a ``####`` marker."""

    template: str = textwrap.dedent(
        """\
        {{ preamble }}

        {% if examples %}
        Below are some examples of how to answer questions:

        {% for ex in examples %}
        Example {{ loop.index }}:
        Question: {{ ex.question }}
        {% if ex.context %}
        Context: {{ ex.context }}
        {% endif %}
        Answer: {{ ex.answer }}
        ---
        {% endfor %}
        Now please answer the following question.
        {% endif %}

        Question: {{ question }}
        {% if context %}
        Context:
        {{ context }}
        {% endif %}

        IMPORTANT: Think step-by-step before giving your final answer, then provide your final answer after "#### ".

        {% if not examples %}
        Example Answer:
        Your explanation and reasoning can go here...
        #### [your final answer here]
        {% endif %}
        """
    )
    """Base template to be rendered via Jinja2."""

    def __init__(
        self,
    ) -> None:
        """Set up the prompt template; no configuration is taken."""
        super().__init__()

    @override
    def render(
        self,
        *,
        question: str,
        context: str | Sequence[str] | None = None,
        examples: Sequence[typings.QuestionAnswerExample] | None = None,
        preamble: str | None = None,
    ) -> str:
        """Fill in the template and return the final prompt text.

        Args:
            question: The question the model should answer.
            context: Optional supporting context text(s); when truthy it is
                formatted with ``format_utils.format_list_items`` first.
            examples: Optional question & answer pairs shown as worked examples.
                Expected format is a list of dicts with 'question', 'answer',
                and optional 'context' keys.
            preamble: Optional text placed at the very top of the prompt.

        Returns:
            The rendered prompt, stripped, terminated by a single newline and
            passed through ``format_utils.remove_multi_blank_lines``.

        Raises:
            ValueError: If `question` is not a string or is blank.
        """
        if not (isinstance(question, str) and question.strip()):
            raise ValueError("`question` must be a non-empty string.")

        # Falsy context (None, "", empty sequence) is rendered as no context at all.
        formatted_context = format_utils.format_list_items(context) if context else None

        prompt = Template(self.template).render(
            question=question.strip(),
            context=formatted_context,
            examples=examples,
            preamble=(preamble or "").strip(),
        )

        # Normalise surrounding whitespace before the blank-line cleanup.
        prompt = textwrap.dedent(prompt).strip() + "\n"
        return format_utils.remove_multi_blank_lines(prompt)
Loading