
Commit b3a3bf7

Add EvalAssist judges integration
Signed-off-by: Martín Santillán Cooper <[email protected]>
1 parent 92827a3 commit b3a3bf7

File tree

4 files changed, +386 -0 lines changed

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
from unitxt.api import create_dataset, evaluate
from unitxt.evalassist_judge import EvalAssistLLMJudgeDirect
from unitxt.inference import CrossProviderInferenceEngine
from unitxt.llm_as_judge_constants import (
    CriteriaWithOptions,
)

criteria = CriteriaWithOptions.from_obj(
    {
        "name": "Temperature in Fahrenheit and Celsius",
        "description": "In the response, if there is a numerical temperature present, is it denominated in both Fahrenheit and Celsius?",
        "options": [
            {
                "name": "Correct",
                "description": "The temperature reading is provided in both Fahrenheit and Celsius.",
            },
            {
                "name": "Partially Correct",
                "description": "The temperature reading is provided either in Fahrenheit or Celsius, but not both.",
            },
            {
                "name": "Incorrect",
                "description": "There is no numerical temperature reading in the response.",
            },
        ],
        "option_map": {"Correct": 1.0, "Partially Correct": 0.5, "Incorrect": 0.0},
        "context_fields": ["question"],
    }
)

data = [
    {"question": "How is the weather?"},
    {"question": "How is the weather?"},
    {"question": "How is the weather?"},
]

metric = EvalAssistLLMJudgeDirect(
    inference_engine=CrossProviderInferenceEngine(
        model="llama-3-3-70b-instruct",
        max_tokens=1024,
        data_classification_policy=["public"],
    ),
    criteria=criteria,
)

dataset = create_dataset(
    task="tasks.qa.open", test_set=data, metrics=[metric], split="test"
)

predictions = [
    """On most days, the weather is warm and humid, with temperatures often soaring into the high 80s and low 90s Fahrenheit (around 31-34°C). The dense foliage of the jungle acts as a natural air conditioner, keeping the temperature relatively stable and comfortable for the inhabitants.""",
    """On most days, the weather is warm and humid, with temperatures often soaring into the high 80s and low 90s Fahrenheit. The dense foliage of the jungle acts as a natural air conditioner, keeping the temperature relatively stable and comfortable for the inhabitants.""",
    """On most days, the weather is warm and humid. The dense foliage of the jungle acts as a natural air conditioner, keeping the temperature relatively stable and comfortable for the inhabitants.""",
]

results = evaluate(predictions=predictions, data=dataset)

print("Global Scores:")
print(results.global_scores.summary)

print("Instance Scores:")
print(results.instance_scores)
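
A minimal sketch of what this example is expected to produce (not part of the commit; exact keys depend on the unitxt version): the judge derives its main score name from the criteria name and prefixes the auxiliary result fields with it, so the instance scores should contain keys along these lines.

# Illustrative only: the main score name is derived from "Temperature in Fahrenheit and Celsius".
expected_instance_score_keys = [
    "temperature_in_fahrenheit_and_celsius",                  # numeric value from option_map
    "temperature_in_fahrenheit_and_celsius_selected_option",  # e.g. "Correct"
    "temperature_in_fahrenheit_and_celsius_explanation",      # the judge's assessment text
]
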
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
from unitxt import add_to_catalog
from unitxt.evalassist_judge import EvalAssistLLMJudgeDirect
from unitxt.inference import CrossProviderInferenceEngine

eval_assist_judge = EvalAssistLLMJudgeDirect(
    inference_engine=CrossProviderInferenceEngine(
        model="llama-3-3-70b-instruct",
        max_tokens=1024,
        temperature=0.0,
    ),
)

add_to_catalog(
    eval_assist_judge,
    "metrics.evalassist_judge.direct.watsonx.llama3_3_70b",
    overwrite=True,
)
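
A minimal usage sketch (not part of the commit; it reuses the fetch_artifact call seen in src/unitxt/evalassist_judge.py): once registered, the judge can be fetched from the catalog by name. Because the catalog entry above carries no criteria, a criteria still has to be supplied before evaluation.

from unitxt.artifact import fetch_artifact

judge = fetch_artifact("metrics.evalassist_judge.direct.watsonx.llama3_3_70b")[0]
judge.criteria = criteria  # a CriteriaWithOptions, e.g. as defined in the first example
# `judge` can then be passed to create_dataset(..., metrics=[judge], ...) as shown above.
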
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
{
    "__type__": "eval_assist_llm_judge_direct",
    "inference_engine": {
        "__type__": "cross_provider_inference_engine",
        "model": "llama-3-3-70b-instruct",
        "max_tokens": 1024,
        "temperature": 0.0
    }
}

src/unitxt/evalassist_judge.py

Lines changed: 296 additions & 0 deletions
@@ -0,0 +1,296 @@
from difflib import get_close_matches
from typing import Any, Dict, List, Union, cast

from .artifact import fetch_artifact
from .dict_utils import dict_get
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    PackageRequirementsMixin,
)
from .llm_as_judge_constants import (
    Criteria,
    CriteriaWithOptions,
)
from .logging_utils import get_logger
from .metric_utils import EmptyPrediction
from .metrics import BulkInstanceMetric

logger = get_logger(__name__)

class EvalAssistLLMJudge(BulkInstanceMetric, PackageRequirementsMixin):
    """A metric class to evaluate instances using an LLM as a Judge.

    Evaluations are performed in two steps. First, the LLM is asked to generate an assessment following a CoT approach based on the criteria. Then, the same LLM is asked to select one of the available options. A summary of the general assessment can be generated for easy consumption by end users.
    """

    _requirements_list = {
        "evalassist": "Install the evalassist package using 'pip install --upgrade evalassist'",
    }

    inference_engine: InferenceEngine
    """The engine used for generating predictions in the different evaluation steps."""

    context_fields: Union[str, List[str], Dict[str, str], None] = None
    """Fields to be used as context. If a dict is provided, the keys are used as the final names in the prompts, while the values are used to access the context variable values in the `task_data` object. (It is recommended to provide the context fields via the Criteria's `context_fields` field, as this field will be deprecated in the future.)"""

    check_positional_bias: bool = False
    """Flag to check for positional bias. Checking for positional bias doubles the number of inference calls."""

    criteria_field: Union[str, None] = None
    """The field specifying the evaluation criteria in the `task_data` object. If `criteria` is provided, it takes precedence."""

    criteria: Criteria = None
    """The criteria used for evaluation."""

    def prepare(self):
        """Prepares the `EvalAssistLLMJudge` instance by normalizing the context fields."""
        if self.context_fields is not None:
            self.context_fields = self.get_context_fields_as_dict(self.context_fields)
        super().prepare()

def before_process_multi_stream(self):
54+
"""Checks the criteria-related fields correctness before processing multiple streams.
55+
56+
Raises:
57+
UnitxtError: If both 'criteria' and 'criteria_field' are not set.
58+
"""
59+
super().before_process_multi_stream()
60+
# We check the criteria here and not in verify(), because we want catalog
61+
# may contain a partially initialized object, and verify() method
62+
# is called when creating the object and not when using it.
63+
if self.criteria is None and self.criteria_field is None:
64+
raise UnitxtError(
65+
f"You must set either the 'criteria' field of the {__class__.__name__} metric to define one criteria to evaluate on all instance, or set a 'criteria_field' of the metric to evaluate on each instance based on the criteria specified in that field of each instance."
66+
)
67+
return
68+
69+
    def get_context_fields_as_dict(
        self, context_fields: Union[str, List[str], Dict[str, str]]
    ):
        result = context_fields if context_fields else {}
        if isinstance(result, str):
            result = [result]
        if isinstance(result, List):
            result = {context_field: context_field for context_field in result}
        return result

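    # Illustrative normalization examples (added for clarity, not part of the commit):
    #   "question"               -> {"question": "question"}
    #   ["question", "context"]  -> {"question": "question", "context": "context"}
    #   {"Question": "question"} -> returned unchanged (prompt name -> task_data key)
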
    def get_contexts(
        self, task_data: List[Dict[str, Any]], criteria: List[Criteria]
    ) -> List[Dict[str, str]]:
        """Extracts and parses context fields from task data.

        Args:
            task_data (List[Dict[str, Any]]): The task data containing context information.
            criteria (List[Criteria]): The criteria list from which to take the context fields if they were not provided via `self.context_fields`.

        Returns:
            List[Dict[str, str]]: A list of parsed context dictionaries.
        """
        parsed_contexts = []
        for i, td in enumerate(task_data):
            context_fields_for_td = self.context_fields
            if not context_fields_for_td and criteria[i].context_fields:
                context_fields_for_td = self.get_context_fields_as_dict(
                    criteria[i].context_fields
                )

            parsed_contexts.append(
                {
                    context_field_name: str(dict_get(td, context_field))
                    for context_field_name, context_field in context_fields_for_td.items()
                }
            )
        return parsed_contexts

    def clean_results(self, results: Union[dict, list]):
        """Cleans the results by removing `None` values and empty lists and dictionaries.

        Args:
            results (Union[dict, list]): The results to clean.

        Returns:
            Union[dict, list]: The cleaned results.
        """
        if isinstance(results, list):
            return [self.clean_results(x) for x in results]
        cleaned = {
            k: (v if not isinstance(v, dict) else self.clean_results(v))
            for k, v in results.items()
            if v is not None and not (isinstance(v, (list, dict)) and len(v) == 0)
        }
        # Remove the dictionary itself if it becomes empty
        return {
            k: v
            for k, v in cleaned.items()
            if not (isinstance(v, dict) and len(v) == 0)
        }

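    # Illustrative example (added for clarity, not part of the commit):
    #   clean_results({"explanation": None, "metadata": {"unused": []}, "score": 1.0})
    #   -> {"score": 1.0}  (None values and empty containers are dropped recursively)
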
    def get_criteria(self, task_data, eval_count) -> List[Criteria]:
        """Retrieves the evaluation criteria from the `criteria_field` or from `self`.

        Args:
            task_data (List[Dict[str, Any]]): The task data containing criteria information.
            eval_count (int): The number of evaluations to perform.

        Returns:
            List[Criteria]: A list of criteria for evaluation.

        Raises:
            UnitxtError: If the criteria field is not found in the task data.
        """
        criteria_list: List[Criteria]
        if self.criteria is None:
            if any(self.criteria_field not in td for td in task_data):
                raise UnitxtError(
                    f"The criteria field `{self.criteria_field}` required for {__class__.__name__} is not found in the instance. Perhaps you meant '{get_close_matches(cast(str, self.criteria_field), task_data[0].keys(), n=1, cutoff=0.0)[0]}'?"
                )
            logger.info(
                f"Reading criteria from the task_data field '{self.criteria_field}'"
            )
            criteria_list = [
                cast(
                    Criteria, fetch_artifact(task_data_instance[self.criteria_field])[0]
                )
                for task_data_instance in task_data
            ]
        else:
            logger.info(
                "Reading criteria from self. Criteria is a single CriteriaWithOptions, replicating it for all predictions"
            )
            criteria_list = [self.criteria] * eval_count
        unique_criteria_names = list({criteria.name for criteria in criteria_list})

        logger.info(f"Criteria names are '{', '.join(unique_criteria_names)}'")
        return criteria_list

    def get_predictions(
        self,
        task_data: List[Dict[str, Any]],
        criteria: List[Criteria],
        predictions: List[str],
    ) -> List[str]:
        if not predictions or all(
            (
                isinstance(prediction, EmptyPrediction)
                or prediction == str(EmptyPrediction())
            )
            for prediction in predictions
        ):
            predictions_from_task_data = []
            for i, td in enumerate(task_data):
                if (
                    criteria[i].prediction_field is not None
                    and criteria[i].prediction_field in td
                ):
                    predictions_from_task_data.append(
                        dict_get(td, criteria[i].prediction_field)  # type: ignore
                    )
                else:
                    raise UnitxtError(
                        "You must either set the predictions in the evaluate() call or specify, via the `Criteria`'s prediction_field, the name of the task_data field from which the prediction should be taken."
                    )
            return predictions_from_task_data

        return predictions

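# Illustrative note (added for clarity; field names are hypothetical): when evaluate()
# is called without predictions, get_predictions() falls back to the task_data field
# named by each criterion's `prediction_field`, e.g. prediction_field="answer" reads
# td["answer"] as the response to judge.
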
class EvalAssistLLMJudgeDirect(EvalAssistLLMJudge):
    """EvalAssistLLMJudgeDirect is a specialized evaluation metric that performs Direct Assessment using an LLM to score responses based on predefined evaluation criteria.

    Direct Assessment is an evaluation paradigm in which the LLM selects one of a
    predefined set of options based on an assessment criterion. This approach can
    be used for Likert-scale scoring (e.g., 1-5) or selecting from semantically
    conditioned literals (e.g., Yes/No, Pass/Fail).
    """

    criteria: CriteriaWithOptions = None
    """The evaluation criteria, including a name, a description, a predefined set of options, and an option_map."""
    main_score = "llm_as_judge"
    """The primary score name used in the results. By default, it takes the value of the criteria name (if only one criteria is used for evaluation) or "llm_as_judge" otherwise."""
    reduction_map = {"mean": ["llm_as_judge"]}
    """A mapping used for score aggregation. By default, it takes the value of ``{'mean': [<default_main_score_name>]}``."""

    def before_process_multi_stream(self):
        """Ensures that the criteria is of type `CriteriaWithOptions`, raising an exception otherwise."""
        super().before_process_multi_stream()
        if self.criteria is not None and not isinstance(
            self.criteria, CriteriaWithOptions
        ):
            raise Exception(
                f"The type of the criteria must be 'CriteriaWithOptions', instead it is of type '{type(self.criteria)}'"
            )
        return

    def __set_main_score(self, criterias: List[Criteria]):
        unique_criteria_names = list({criteria.name for criteria in criterias})
        if len(unique_criteria_names) == 1 and criterias[0].name != "":
            self.main_score = "_".join(criterias[0].name.lower().split(" "))
            self.reduction_map = {"mean": [self.main_score]}

    def compute(
        self,
        references: List[List[str]],
        predictions: List[str],
        task_data: List[Dict[str, Any]],
    ) -> List[Dict]:
        logger.info(
            f"Starting evaluation with judge {self.inference_engine.get_pretty_print_name()}"
        )
        from evalassist.judges import (
            Criteria,
            DirectInstance,
            DirectInstanceResult,
            DirectJudge,
        )

        judge = DirectJudge(self.inference_engine)

        evaluations_count = len(task_data)
        # TODO: find out how to serialize and deserialize enums
        criteria = self.get_criteria(task_data, evaluations_count)
        self.__set_main_score(criteria)
        predictions = self.get_predictions(task_data, criteria, predictions)
        contexts = self.get_contexts(task_data, criteria)
        eval_assist_criteria = [
            Criteria.from_unitxt_criteria(criterion) for criterion in criteria
        ]

        instances = [
            DirectInstance(
                context=context,
                response=prediction,
            )
            for prediction, context in zip(predictions, contexts)
        ]

        results: list[DirectInstanceResult] = judge(
            instances=instances,
            criteria=eval_assist_criteria,
            check_positional_bias=self.check_positional_bias,
        )

        parsed_results: list[dict] = [
            {
                "selected_option": r.option,
                "explanation": r.explanation,
                "feedback": r.feedback if r.feedback is not None else None,
                "prompt": r.metadata["prompt"],
                "positional_bias": r.positional_bias.detected
                if r.positional_bias is not None
                else None,
                self.main_score: r.score if r.score is not None else 0.0,
            }
            for r in results
        ]

        parsed_results = [
            {
                f"{self.main_score}_{k}" if k != self.main_score else self.main_score: v
                for k, v in r.items()
            }
            for r in parsed_results
        ]

        return self.clean_results(parsed_results)
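
The Direct Assessment docstring above mentions Likert-scale scoring; below is a minimal sketch (not part of the commit; the name, descriptions, and mapping are hypothetical) of how such a criteria could be expressed with the same CriteriaWithOptions.from_obj structure used in the first example.

from unitxt.llm_as_judge_constants import CriteriaWithOptions

# Hypothetical 1-5 Likert criteria, mapped to a 0.0-1.0 score via option_map.
likert_criteria = CriteriaWithOptions.from_obj(
    {
        "name": "Answer Helpfulness",
        "description": "How helpful is the response to the user's question?",
        "options": [
            {"name": str(i), "description": f"Helpfulness rating of {i} out of 5."}
            for i in range(1, 6)
        ],
        "option_map": {str(i): (i - 1) / 4 for i in range(1, 6)},
        "context_fields": ["question"],
    }
)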
