+from difflib import get_close_matches
+from typing import Any, Dict, List, Union, cast
+
+from .artifact import fetch_artifact
+from .dict_utils import dict_get
+from .error_utils import UnitxtError
+from .inference import (
+    InferenceEngine,
+    PackageRequirementsMixin,
+)
+from .llm_as_judge_constants import (
+    Criteria,
+    CriteriaWithOptions,
+)
+from .logging_utils import get_logger
+from .metric_utils import EmptyPrediction
+from .metrics import BulkInstanceMetric
+
+logger = get_logger(__name__)
+
+
+class EvalAssistLLMJudge(BulkInstanceMetric, PackageRequirementsMixin):
+    """A metric class to evaluate instances using LLM as a Judge.
+
+    Evaluations are performed in two steps. First, the LLM is asked to generate an assessment following a chain-of-thought (CoT) approach based on the criteria. Then, the same LLM is asked to select one of the available options. A summary of the general assessment can be generated for easy consumption by end users.
+    """
+
+    _requirements_list = {
+        "evalassist": "Install the evalassist package using 'pip install --upgrade evalassist'",
+    }
+
+    inference_engine: InferenceEngine
+    """The engine used for generating predictions in the different evaluation steps."""
+
+    context_fields: Union[str, List[str], Dict[str, str], None] = None
+    """Fields to be used as context. If a dict is provided, the keys are used as the final names in the prompts, while the values are used to access the context variable values in the `task_data` object (it is recommended to provide the context fields in the Criteria's `context_fields` field, as this field will be deprecated in the future)."""
+
+    check_positional_bias: bool = False
+    """Flag to check for positional bias. Checking for positional bias doubles the number of inference calls."""
+
+    criteria_field: Union[str, None] = None
+    """The field specifying the evaluation criteria in the `task_data` object. If `criteria` is provided, it takes precedence."""
+
+    criteria: Criteria = None
+    """The criteria used for evaluation."""
+
+    def prepare(self):
+        """Prepares the `EvalAssistLLMJudge` instance by setting up the context fields."""
+        if self.context_fields is not None:
+            self.context_fields = self.get_context_fields_as_dict(self.context_fields)
+        super().prepare()
+
+    def before_process_multi_stream(self):
+        """Checks the correctness of the criteria-related fields before processing multiple streams.
+
+        Raises:
+            UnitxtError: If neither 'criteria' nor 'criteria_field' is set.
+        """
+        super().before_process_multi_stream()
+        # We check the criteria here and not in verify(), because the catalog
+        # may contain a partially initialized object, and the verify() method
+        # is called when creating the object and not when using it.
+        if self.criteria is None and self.criteria_field is None:
+            raise UnitxtError(
+                f"You must set either the 'criteria' field of the {__class__.__name__} metric to evaluate all instances with a single criterion, or the 'criteria_field' of the metric to evaluate each instance based on the criteria specified in that field of the instance."
+            )
+        return
+
+    def get_context_fields_as_dict(
+        self, context_fields: Union[str, List[str], Dict[str, str]]
+    ):
+        result = context_fields if context_fields else {}
+        if isinstance(result, str):
+            result = [result]
+        if isinstance(result, List):
+            result = {context_field: context_field for context_field in result}
+        return result
+
+    def get_contexts(
+        self, task_data: List[Dict[str, Any]], criteria: List[Criteria]
+    ) -> List[Dict[str, str]]:
+        """Extracts and parses context fields from task data.
+
+        Args:
+            task_data (List[Dict[str, Any]]): The task data containing context information.
+            criteria (List[Criteria]): The criteria list from which to take the context fields if they were not provided in `self.context_fields`.
+
+        Returns:
+            List[Dict[str, str]]: A list of parsed context dictionaries.
+        """
+        parsed_contexts = []
+        for i, td in enumerate(task_data):
+            context_fields_for_td = self.context_fields
+            if not context_fields_for_td and criteria[i].context_fields:
+                context_fields_for_td = self.get_context_fields_as_dict(
+                    criteria[i].context_fields
+                )
+
+            parsed_contexts.append(
+                {
+                    context_field_name: str(dict_get(td, context_field))
+                    for context_field_name, context_field in context_fields_for_td.items()
+                }
+            )
+        return parsed_contexts
+
+    def clean_results(self, results: Union[dict, list]):
+        """Cleans the results by removing `None` values and empty lists and dictionaries.
+
+        Args:
+            results (Union[dict, list]): The results to clean.
+
+        Returns:
+            Union[dict, list]: The cleaned results.
+        """
+        if isinstance(results, list):
+            return [self.clean_results(x) for x in results]
+        cleaned = {
+            k: (v if not isinstance(v, dict) else self.clean_results(v))
+            for k, v in results.items()
+            if v is not None and not (isinstance(v, (list, dict)) and len(v) == 0)
+        }
+        # Remove the dictionary itself if it becomes empty
+        return {
+            k: v
+            for k, v in cleaned.items()
+            if not (isinstance(v, dict) and len(v) == 0)
+        }
+
+    def get_criteria(self, task_data, eval_count) -> List[Criteria]:
+        """Retrieves the evaluation criteria from the `criteria_field` or from `self`.
+
+        Args:
+            task_data (List[Dict[str, Any]]): The task data containing criteria information.
+            eval_count (int): The number of evaluations to perform.
+
+        Returns:
+            List[Criteria]: A list of criteria for evaluation.
+
+        Raises:
+            UnitxtError: If the criteria field is not found in the task data.
+        """
+        criteria_list: List[Criteria]
+        if self.criteria is None:
+            if any(self.criteria_field not in td for td in task_data):
+                raise UnitxtError(
+                    f"The criteria field `{self.criteria_field}` required for {__class__.__name__} is not found in the instance. Perhaps you meant '{get_close_matches(cast(str, self.criteria_field), task_data[0].keys(), n=1, cutoff=0.0)[0]}'?"
+                )
+            logger.info(
+                f"Reading criteria from the task_data field '{self.criteria_field}'"
+            )
+            criteria_list = [
+                cast(
+                    Criteria, fetch_artifact(task_data_instance[self.criteria_field])[0]
+                )
+                for task_data_instance in task_data
+            ]
+        else:
+            logger.info(
+                "Reading criteria from self. Criteria is a single CriteriaWithOptions, replicating it for all predictions"
+            )
+            criteria_list = [self.criteria] * eval_count
+        unique_criteria_names = list({criteria.name for criteria in criteria_list})
+
+        logger.info(f"Criteria names are '{', '.join(unique_criteria_names)}'")
+        return criteria_list
+
+    def get_predictions(
+        self,
+        task_data: List[Dict[str, Any]],
+        criteria: List[Criteria],
+        predictions: List[str],
+    ) -> List[str]:
+        if not predictions or all(
+            (
+                isinstance(prediction, EmptyPrediction)
+                or prediction == str(EmptyPrediction())
+            )
+            for prediction in predictions
+        ):
+            predictions_from_task_data = []
+            for i, td in enumerate(task_data):
+                if (
+                    criteria[i].prediction_field is not None
+                    and criteria[i].prediction_field in td
+                ):
+                    predictions_from_task_data.append(
+                        dict_get(td, criteria[i].prediction_field)  # type: ignore
+                    )
+                else:
+                    raise UnitxtError(
+                        "You must either provide predictions in the evaluate() call or specify, via the Criteria's `prediction_field`, the task_data field from which the prediction should be taken."
+                    )
+            return predictions_from_task_data
+
+        return predictions
+
+
+class EvalAssistLLMJudgeDirect(EvalAssistLLMJudge):
+    """EvalAssistLLMJudgeDirect is a specialized evaluation metric that performs Direct Assessment, using an LLM to score responses based on predefined evaluation criteria.
+
+    Direct Assessment is an evaluation paradigm in which the LLM selects one of a
+    predefined set of options based on an assessment criterion. This approach can
+    be used for Likert-scale scoring (e.g., 1-5) or selecting from semantically
+    conditioned literals (e.g., Yes/No, Pass/Fail).
+    """
+
+    criteria: CriteriaWithOptions = None
+    """The evaluation criteria, including a name, a description, a predefined set of options, and an option_map."""
+    main_score = "llm_as_judge"
+    """The primary score name used in the results. By default, it takes the value of the criteria name (if only one criteria is used for evaluation) or "llm_as_judge" otherwise."""
+    reduction_map = {"mean": ["llm_as_judge"]}
+    """A mapping used for score aggregation. By default, it takes the value of ``{'mean': [<default_main_score_name>]}``."""
+
+    def before_process_multi_stream(self):
+        """Ensures that the criteria is of type `CriteriaWithOptions`, raising an exception otherwise."""
+        super().before_process_multi_stream()
+        if self.criteria is not None and not isinstance(
+            self.criteria, CriteriaWithOptions
+        ):
+            raise Exception(
+                f"The type of the criteria must be 'CriteriaWithOptions', instead it is of type '{type(self.criteria)}'"
+            )
+        return
+
+    def __set_main_score(self, criterias: List[Criteria]):
+        unique_criteria_names = list({criteria.name for criteria in criterias})
+        if len(unique_criteria_names) == 1 and criterias[0].name != "":
+            self.main_score = "_".join(criterias[0].name.lower().split(" "))
+            self.reduction_map = {"mean": [self.main_score]}
+
+    def compute(
+        self,
+        references: List[List[str]],
+        predictions: List[str],
+        task_data: List[Dict[str, Any]],
+    ) -> List[Dict]:
+        logger.info(
+            f"Starting evaluation with judge {self.inference_engine.get_pretty_print_name()}"
+        )
+        from evalassist.judges import (
+            Criteria,
+            DirectInstance,
+            DirectInstanceResult,
+            DirectJudge,
+        )
+
+        judge = DirectJudge(self.inference_engine)
+
+        evaluations_count = len(task_data)
+        # TODO: find out how to serialize and deserialize enums
+        criteria = self.get_criteria(task_data, evaluations_count)
+        self.__set_main_score(criteria)
+        predictions = self.get_predictions(task_data, criteria, predictions)
+        contexts = self.get_contexts(task_data, criteria)
+        eval_assist_criteria = [
+            Criteria.from_unitxt_criteria(criterion) for criterion in criteria
+        ]
+
+        instances = [
+            DirectInstance(
+                context=context,
+                response=prediction,
+            )
+            for prediction, context in zip(predictions, contexts)
+        ]
+
+        results: list[DirectInstanceResult] = judge(
+            instances=instances,
+            criteria=eval_assist_criteria,
+            check_positional_bias=self.check_positional_bias,
+        )
+
+        parsed_results: list[dict] = [
+            {
+                "selected_option": r.option,
+                "explanation": r.explanation,
+                "feedback": r.feedback if r.feedback is not None else None,
+                "prompt": r.metadata["prompt"],
+                "positional_bias": r.positional_bias.detected
+                if r.positional_bias is not None
+                else None,
+                self.main_score: r.score if r.score is not None else 0.0,
+            }
+            for r in results
+        ]
+
+        parsed_results = [
+            {
+                f"{self.main_score}_{k}" if k != self.main_score else self.main_score: v
+                for k, v in r.items()
+            }
+            for r in parsed_results
+        ]
+
+        return self.clean_results(parsed_results)
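
The context handling above normalizes whatever form context_fields takes into a dict mapping prompt names to task_data fields. A small standalone sketch of that normalization (the same logic as get_context_fields_as_dict, reproduced outside the class purely for illustration; the function name is hypothetical):

from typing import Dict, List, Union

def normalize_context_fields(
    context_fields: Union[str, List[str], Dict[str, str], None],
) -> Dict[str, str]:
    result = context_fields if context_fields else {}
    if isinstance(result, str):
        result = [result]  # "question" -> ["question"]
    if isinstance(result, list):
        # each field maps to itself when no explicit prompt name is given
        result = {field: field for field in result}
    return result

assert normalize_context_fields("question") == {"question": "question"}
assert normalize_context_fields(["question", "reference"]) == {"question": "question", "reference": "reference"}
assert normalize_context_fields({"Context": "document"}) == {"Context": "document"}
assert normalize_context_fields(None) == {}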
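
Similarly, the pruning performed by clean_results drops None values and empty lists/dicts recursively, and removes nested dicts that become empty after pruning. A standalone sketch of the same logic with a concrete input/output pair:

from typing import Union

def clean_results(results: Union[dict, list]):
    if isinstance(results, list):
        return [clean_results(x) for x in results]
    cleaned = {
        k: (v if not isinstance(v, dict) else clean_results(v))
        for k, v in results.items()
        if v is not None and not (isinstance(v, (list, dict)) and len(v) == 0)
    }
    # drop nested dicts that became empty after pruning
    return {k: v for k, v in cleaned.items() if not (isinstance(v, dict) and len(v) == 0)}

assert clean_results(
    {"score": 0.8, "feedback": None, "meta": {"prompt": [], "bias": None}, "tags": []}
) == {"score": 0.8}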
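
Finally, a minimal usage sketch (not part of this diff) of how the new EvalAssistLLMJudgeDirect metric might be configured. The criteria classes follow unitxt's llm_as_judge_constants; the inference engine, model name, and provider are placeholders, and the import path of the new class is not shown in the diff.

from unitxt.inference import CrossProviderInferenceEngine
from unitxt.llm_as_judge_constants import CriteriaOption, CriteriaWithOptions
# EvalAssistLLMJudgeDirect comes from the module added in this diff (import path not shown here).

# A yes/no criteria whose option_map turns the selected option into a numeric score.
criteria = CriteriaWithOptions(
    name="answer_relevance",
    description="Is the response relevant to the question?",
    options=[
        CriteriaOption(name="Yes", description="The response addresses the question."),
        CriteriaOption(name="No", description="The response does not address the question."),
    ],
    option_map={"Yes": 1.0, "No": 0.0},
)

metric = EvalAssistLLMJudgeDirect(
    inference_engine=CrossProviderInferenceEngine(
        model="llama-3-3-70b-instruct",  # placeholder model name
        provider="watsonx",  # placeholder provider
    ),
    criteria=criteria,
    context_fields=["question"],
)

With a single criteria like this, __set_main_score renames the main score to "answer_relevance", and the remaining result keys are prefixed with it (e.g. "answer_relevance_explanation", "answer_relevance_positional_bias") before clean_results prunes empty entries.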