diff --git a/.github/workflows/catalog_consistency.yml b/.github/workflows/catalog_consistency.yml
index 4b42a8843b..96eab4c6e5 100644
--- a/.github/workflows/catalog_consistency.yml
+++ b/.github/workflows/catalog_consistency.yml
@@ -30,7 +30,7 @@ jobs:
 
     - uses: actions/setup-python@v5
       with:
-        python-version: '3.9'
+        python-version: '3.10'
     - run: curl -LsSf https://astral.sh/uv/install.sh | sh
    - run: uv pip install --system -e ".[tests]"
 
diff --git a/examples/evaluate_evalassist_judge.py b/examples/evaluate_evalassist_judge.py
new file mode 100644
index 0000000000..7f192ecc0e
--- /dev/null
+++ b/examples/evaluate_evalassist_judge.py
@@ -0,0 +1,64 @@
+from unitxt.api import create_dataset, evaluate
+from unitxt.evalassist_judge import EvalAssistLLMJudgeDirect
+from unitxt.inference import CrossProviderInferenceEngine
+from unitxt.llm_as_judge_constants import (
+    CriteriaWithOptions,
+)
+
+criteria = CriteriaWithOptions.from_obj(
+    {
+        "name": "Temperature in Fahrenheit and Celsius",
+        "description": "In the response, if there is a numerical temperature present, is it denominated in both Fahrenheit and Celsius?",
+        "options": [
+            {
+                "name": "Correct",
+                "description": "The temperature reading is provided in both Fahrenheit and Celsius.",
+            },
+            {
+                "name": "Partially Correct",
+                "description": "The temperature reading is provided either in Fahrenheit or Celsius, but not both.",
+            },
+            {
+                "name": "Incorrect",
+                "description": "There is no numerical temperature reading in the response.",
+            },
+        ],
+        "option_map": {"Correct": 1.0, "Partially Correct": 0.5, "Incorrect": 0.0},
+        "context_fields": ["question"],
+    }
+)
+
+
+data = [
+    {"question": "How is the weather?"},
+    {"question": "How is the weather?"},
+    {"question": "How is the weather?"},
+]
+
+metric = EvalAssistLLMJudgeDirect(
+    inference_engine=CrossProviderInferenceEngine(
+        model="llama-3-3-70b-instruct",
+        max_tokens=1024,
+        data_classification_policy=["public"],
+    ),
+    criteria=criteria,
+)
+
+
+dataset = create_dataset(
+    task="tasks.qa.open", test_set=data, metrics=[metric], split="test"
+)
+
+predictions = [
+    """On most days, the weather is warm and humid, with temperatures often soaring into the high 80s and low 90s Fahrenheit (around 31-34°C). The dense foliage of the jungle acts as a natural air conditioner, keeping the temperature relatively stable and comfortable for the inhabitants.""",
+    """On most days, the weather is warm and humid, with temperatures often soaring into the high 80s and low 90s Fahrenheit. The dense foliage of the jungle acts as a natural air conditioner, keeping the temperature relatively stable and comfortable for the inhabitants.""",
+    """On most days, the weather is warm and humid. The dense foliage of the jungle acts as a natural air conditioner, keeping the temperature relatively stable and comfortable for the inhabitants.""",
+]
+
+results = evaluate(predictions=predictions, data=dataset)
+
+print("Global Scores:")
+print(results.global_scores.summary)
+
+print("Instance Scores:")
+print(results.instance_scores)
diff --git a/examples/evaluate_faithfulness_metrics.py b/examples/evaluate_faithfulness_metrics.py
new file mode 100644
index 0000000000..045027a9cc
--- /dev/null
+++ b/examples/evaluate_faithfulness_metrics.py
@@ -0,0 +1,151 @@
+import json
+
+import pandas as pd
+import unitxt
+from unitxt.api import evaluate, load_dataset
+from unitxt.benchmark import Benchmark
+from unitxt.inference import MetricInferenceEngine
+from unitxt.standard import DatasetRecipe
+from unitxt.templates import InputOutputTemplate
+
+unitxt.settings.allow_unverified_code = True
+unitxt.settings.dataset_cache_default = True
+
+card_subsets = [
+    "covidqa",
+    "cuad",
+    "delucionqa",
+    "emanual",
+    "expertqa",
+    "finqa",
+    "hagrid",
+    "hotpotqa",
+    "msmarco",
+    "pubmedqa",
+    "tatqa",
+    # "techqa"  # Fails due to bad char in text
+]
+
+# card_subsets = ["covidqa"]
+card = "cards.rag_eval.faithfulness.ragbench"
+
+template = InputOutputTemplate(
+    output_format="{number_val}",
+    input_format="{question}",  # "CONTEXTS:{contexts}\n\n\n\QUESTION:{question}\n\n\nANSWER:{answer}",
+    postprocessors=["processors.cast_to_float_return_0_5_if_failed"],
+)
+
+subsets = {
+    card_subset: DatasetRecipe(
+        card=f"{card}.{card_subset}",
+        template=template,
+        metrics=[
+            "metrics.f1_binary",
+            "metrics.f1_binary[average=macro,score_prefix=macro_]",
+        ],
+    )
+    for card_subset in card_subsets
+}
+
+benchmark = Benchmark(
+    format="formats.empty",
+    max_samples_per_subset=40,
+    loader_limit=300,
+    subsets=subsets,
+)
+
+dataset = load_dataset(
+    benchmark,
+    split="test",
+)
+for instance in dataset:
+    task_data = json.loads(instance["task_data"])
+
+
+metrics_to_score_names = {}
+
+criterion = "metrics.llm_as_judge.direct.criteria.reference_document_faithfulness"
+llm_as_judge_metric = f"metrics.llm_as_judge.direct.rits.llama3_3_70b[check_positional_bias=False,criteria={criterion}, context_fields=[contexts,question]]"
+llm_score_name = "reference_document_faithfulness"
+metrics_to_score_names[llm_as_judge_metric] = llm_score_name
+
+llm_as_judge_metric = f"metrics.llm_as_judge.direct.watsonx.llama3_3_70b[check_positional_bias=False,criteria={criterion}, context_fields=[contexts,question]]"
+metrics_to_score_names[llm_as_judge_metric] = llm_score_name
+
+llm_as_judge_metric = f"metrics.llm_as_judge.evalassist.direct.rits.llama3_3_70b[criteria={criterion},context_fields=[contexts,question]]"
+metrics_to_score_names[llm_as_judge_metric] = llm_score_name
+
+llm_as_judge_metric = f"metrics.llm_as_judge.evalassist.direct.watsonx.llama3_3_70b[criteria={criterion},context_fields=[contexts,question]]"
+metrics_to_score_names[llm_as_judge_metric] = llm_score_name
+
+criterion = "metrics.llm_as_judge.direct.criteria.reference_document_faithfulness2"
+llm_score_name = "reference_document_faithfulness2"
+llm_as_judge_metric = f"metrics.llm_as_judge.evalassist.direct.rits.llama3_3_70b[criteria={criterion},context_fields=[contexts,question]]"
+metrics_to_score_names[llm_as_judge_metric] = llm_score_name
+
+llm_as_judge_metric = f"metrics.llm_as_judge.evalassist.direct.watsonx.llama3_3_70b[criteria={criterion},context_fields=[contexts,question]]"
+metrics_to_score_names[llm_as_judge_metric] = llm_score_name
+
+
+llm_as_judge_metric = (
+    "metrics.rag.external_rag.faithfulness.llama_3_3_70b_instruct_watsonx_judge"
+)
+llm_score_name = "faithfulness_judge"
+metrics_to_score_names[llm_as_judge_metric] = llm_score_name
+metrics_to_score_names["all_one"] = "score"
+df = pd.DataFrame(
+    columns=[
+        "metric",
+        "f1_macro",
+        "f1_faithful",
+        "f1_not_faithful",
+        "num_of_instances",
+    ]
+)
+
+for metric, score_name in metrics_to_score_names.items():
+    # print(json.dumps(task_data,indent=4))
+    # print(json.dumps(instance,indent=4))
+    # print(instance["references"])
+
+    if metric == "all_one":
+        new_predictions = [1.0] * len(dataset)
+    else:
+        model = MetricInferenceEngine(metric=metric, prediction_field="answer")
+        predictions = model(dataset)
+        new_predictions = [prediction[score_name] for prediction in predictions]
+    results = evaluate(
+        predictions=new_predictions, data=dataset, calc_confidence_intervals=False
+    )
+
+    sums = {}
+    counts = {}
+
+    for _, inner_dict in results.subsets_scores.items():
+        if isinstance(inner_dict, dict):
+            for key, value in inner_dict.items():
+                if isinstance(value, float):
+                    sums[key] = sums.get(key, 0) + value
+                    counts[key] = counts.get(key, 0) + 1
+    #
+    averages = {key: sums[key] / counts[key] for key in sums}
+
+    df.loc[len(df)] = [
+        str(metric),
+        averages["macro_f1_binary"],
+        averages["f1_binary"],
+        averages["f1_binary_neg"],
+        results.global_scores["num_of_instances"],
+    ]
+
+    print("Instance Results:")
+    print(results.instance_scores.summary)
+
+    print("Subsets Results (details):")
+    print(results.subsets_scores)
+
+    print("Subsets Results :")
+    print(results.subsets_scores.summary)
+
+df = df.round(decimals=2)
+print(df.to_markdown())
diff --git a/prepare/metrics/llm_as_judge/evalassist_judge.py b/prepare/metrics/llm_as_judge/evalassist_judge.py
new file mode 100644
index 0000000000..e8f3ade381
--- /dev/null
+++ b/prepare/metrics/llm_as_judge/evalassist_judge.py
@@ -0,0 +1,24 @@
+from unitxt import add_to_catalog
+from unitxt.evalassist_judge import EvalAssistLLMJudgeDirect
+from unitxt.inference import CrossProviderInferenceEngine
+
+for provider in ["watsonx", "rits"]:
+    for model in ["llama-3-3-70b-instruct"]:
+        eval_assist_judge = EvalAssistLLMJudgeDirect(
+            inference_engine=CrossProviderInferenceEngine(
+                provider=provider,
+                model=model,
+                max_tokens=1024,
+                temperature=0.0,
+            )
+        )
+        if model == "llama-3-3-70b-instruct":
+            catalog_model = "llama3_3_70b"
+        else:
+            raise ValueError(f"Model {model} not supported")
+
+        add_to_catalog(
+            eval_assist_judge,
+            f"metrics.llm_as_judge.evalassist.direct.{provider}.{catalog_model}",
+            overwrite=True,
+        )
diff --git a/pyproject.toml b/pyproject.toml
index 5a7db6c150..2862354c82 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -108,7 +108,8 @@ tests = [
     "sqlparse",
     "diskcache",
     "pydantic",
-    "jsonschema_rs"
+    "jsonschema_rs",
+    "evalassist"
 ]
 ui = [
     "gradio",
diff --git a/src/unitxt/catalog/metrics/llm_as_judge/direct/criteria/reference_document_faithfulness2.json b/src/unitxt/catalog/metrics/llm_as_judge/direct/criteria/reference_document_faithfulness2.json
new file mode 100644
index 0000000000..6a99977dbb
--- /dev/null
+++ b/src/unitxt/catalog/metrics/llm_as_judge/direct/criteria/reference_document_faithfulness2.json
@@ -0,0 +1,37 @@
+{
+    "__type__": "criteria_with_options",
+    "name": "reference_document_faithfulness2",
+    "description": "\n    Is the prediction grounded in the reference document?\n\n    To be grounded in the reference document, all the information of the prediction must either be present in the reference document or deducible from the reference document.\n    Base your answer only on the information in the reference document. If the prediction is correct but not present in the reference document, then it is not grounded.\n    ",
+    "prediction_field": "response",
+    "context_fields": [
+        "reference_document"
+    ],
+    "options": [
+        {
+            "__type__": "criteria_option",
+            "name": "Completely grounded",
+            "description": "The prediction is fully grounded in the reference document."
+        },
+        {
+            "__type__": "criteria_option",
+            "name": "Mostly grounded",
+            "description": "The vast majority of the information in the prediction is grounded in the reference document, but there is a small or negligible part of the prediction which is not present in the reference document."
+        },
+        {
+            "__type__": "criteria_option",
+            "name": "Somewhat grounded",
+            "description": "Some of the information in the prediction is grounded in the reference document."
+        },
+        {
+            "__type__": "criteria_option",
+            "name": "Not grounded",
+            "description": "Most or all of the information in the prediction is not grounded in the reference document."
+        }
+    ],
+    "option_map": {
+        "Completely grounded": 1.0,
+        "Mostly grounded": 0.75,
+        "Somewhat grounded": 0.25,
+        "Not grounded": 0.0
+    }
+}
diff --git a/src/unitxt/catalog/metrics/llm_as_judge/evalassist/direct/rits/llama3_3_70b.json b/src/unitxt/catalog/metrics/llm_as_judge/evalassist/direct/rits/llama3_3_70b.json
new file mode 100644
index 0000000000..1ff922a9f6
--- /dev/null
+++ b/src/unitxt/catalog/metrics/llm_as_judge/evalassist/direct/rits/llama3_3_70b.json
@@ -0,0 +1,10 @@
+{
+    "__type__": "eval_assist_llm_judge_direct",
+    "inference_engine": {
+        "__type__": "cross_provider_inference_engine",
+        "provider": "rits",
+        "model": "llama-3-3-70b-instruct",
+        "max_tokens": 1024,
+        "temperature": 0.0
+    }
+}
diff --git a/src/unitxt/catalog/metrics/llm_as_judge/evalassist/direct/watsonx/llama3_3_70b.json b/src/unitxt/catalog/metrics/llm_as_judge/evalassist/direct/watsonx/llama3_3_70b.json
new file mode 100644
index 0000000000..1a4d23487a
--- /dev/null
+++ b/src/unitxt/catalog/metrics/llm_as_judge/evalassist/direct/watsonx/llama3_3_70b.json
@@ -0,0 +1,10 @@
+{
+    "__type__": "eval_assist_llm_judge_direct",
+    "inference_engine": {
+        "__type__": "cross_provider_inference_engine",
+        "provider": "watsonx",
+        "model": "llama-3-3-70b-instruct",
+        "max_tokens": 1024,
+        "temperature": 0.0
+    }
+}
diff --git a/src/unitxt/evalassist_judge.py b/src/unitxt/evalassist_judge.py
new file mode 100644
index 0000000000..8d4d0dc920
--- /dev/null
+++ b/src/unitxt/evalassist_judge.py
@@ -0,0 +1,290 @@
+from difflib import get_close_matches
+from typing import Any, Dict, List, Union, cast
+
+from .artifact import fetch_artifact
+from .dict_utils import dict_get
+from .error_utils import UnitxtError
+from .inference import (
+    InferenceEngine,
+    PackageRequirementsMixin,
+)
+from .llm_as_judge_constants import (
+    Criteria,
+    CriteriaWithOptions,
+)
+from .logging_utils import get_logger
+from .metric_utils import EmptyPrediction
+from .metrics import BulkInstanceMetric
+
+logger = get_logger(__name__)
+
+
+class EvalAssistLLMJudge(BulkInstanceMetric, PackageRequirementsMixin):
+    """A metric class to evaluate instances using LLM as a Judge.
+
+    Evaluations are performed in two steps. First, the LLM is asked to generate an assessment following a CoT approach based on the criteria. Then, the same LLM is asked to select one of the available options. A summary of the general assessment can be generated for easy consumption by end users.
+    """
+
+    _requirements_list = {
+        "evalassist": "Install the evalassist package using 'pip install --upgrade evalassist'",
+    }
+
+    inference_engine: InferenceEngine
+    """The engine used for generating predictions in the different evaluation steps."""
+
+    context_fields: Union[str, List[str], Dict[str, str], None] = None
+    """Fields to be used as context. If a dict is provided, the keys are used as the final names in the prompts, while the values are used to access the context variable values in the `task_data` object (it is recommended to provide the context_fields in the Criteria `context_fields` field as this field will be deprecated in the future)."""
+
+    criteria_field: Union[str, None] = None
+    """The field specifying the evaluation criteria in the `task_data` object. If the `criteria` is provided, it will take precedence."""
+
+    criteria: Criteria = None
+    """The criteria used for evaluation."""
+
+    def prepare(self):
+        """Prepares the `EvalAssistLLMJudge` instance by setting up the context fields."""
+        if self.context_fields is not None:
+            self.context_fields = self.get_context_fields_as_dict(self.context_fields)
+        super().prepare()
+
+    def before_process_multi_stream(self):
+        """Checks the criteria-related fields correctness before processing multiple streams.
+
+        Raises:
+            UnitxtError: If both 'criteria' and 'criteria_field' are not set.
+        """
+        super().before_process_multi_stream()
+        # We check the criteria here and not in verify(), because the catalog
+        # may contain a partially initialized object, and the verify() method
+        # is called when creating the object and not when using it.
+        if self.criteria is None and self.criteria_field is None:
+            raise UnitxtError(
+                f"You must set either the 'criteria' field of the {__class__.__name__} metric to define one criterion to evaluate on all instances, or set a 'criteria_field' of the metric to evaluate on each instance based on the criteria specified in that field of each instance."
+            )
+        return
+
+    def get_context_fields_as_dict(
+        self, context_fields: Union[str, List[str], Dict[str, str]]
+    ):
+        result = context_fields if context_fields else {}
+        if isinstance(result, str):
+            result = [result]
+        if isinstance(result, List):
+            result = {context_field: context_field for context_field in result}
+        return result
+
+    def get_contexts(
+        self, task_data: List[Dict[str, Any]], criteria: List[Criteria]
+    ) -> List[Dict[str, str]]:
+        """Extracts and parses context fields from task data.
+
+        Args:
+            task_data (List[Dict[str, Any]]): The task data containing context information.
+            criteria (List[Criteria]): The criteria list from which to take the context fields if they weren't provided in the self.context_fields field.
+
+        Returns:
+            List[Dict[str, str]]: A list of parsed context dictionaries.
+        """
+        parsed_contexts = []
+        for i, td in enumerate(task_data):
+            context_fields_for_td = self.context_fields
+            if not context_fields_for_td and criteria[i].context_fields:
+                context_fields_for_td = self.get_context_fields_as_dict(
+                    criteria[i].context_fields
+                )
+
+            parsed_contexts.append(
+                {
+                    context_field_name: str(dict_get(td, context_field))
+                    for context_field_name, context_field in context_fields_for_td.items()
+                }
+            )
+        return parsed_contexts
+
+    def clean_results(self, results: Union[dict, list]):
+        """Cleans the results by removing `None` values and empty lists and dictionaries.
+
+        Args:
+            results (Union[dict, list]): The results to clean.
+
+        Returns:
+            Union[dict, list]: The cleaned results.
+        """
+        if isinstance(results, list):
+            return [self.clean_results(x) for x in results]
+        cleaned = {
+            k: (v if not isinstance(v, dict) else self.clean_results(v))
+            for k, v in results.items()
+            if v is not None and not (isinstance(v, (list, dict)) and len(v) == 0)
+        }
+        # Remove the dictionary itself if it becomes empty
+        return {
+            k: v
+            for k, v in cleaned.items()
+            if not (isinstance(v, dict) and len(v) == 0)
+        }
+
+    def get_criteria(self, task_data, eval_count) -> List[Criteria]:
+        """Retrieves the evaluation criteria from the `criteria_field` or from `self`.
+
+        Args:
+            task_data (List[Dict[str, Any]]): The task data containing criteria information.
+            eval_count (int): The number of evaluations to perform.
+
+        Returns:
+            List[Criteria]: A list of criteria for evaluation.
+
+        Raises:
+            UnitxtError: If the criteria field is not found in the task data.
+        """
+        criteria_list: List[Criteria]
+        if self.criteria is None:
+            if any(self.criteria_field not in td for td in task_data):
+                raise UnitxtError(
+                    f"The criteria field '{self.criteria_field}' required for {__class__.__name__} is not found in the instance. Perhaps you meant '{get_close_matches(cast(str, self.criteria_field), task_data[0].keys(), n=1, cutoff=0.0)[0]}'?"
+                )
+            logger.info(
+                f"Reading criteria from the task_data field '{self.criteria_field}'"
+            )
+            criteria_list = [
+                cast(
+                    Criteria, fetch_artifact(task_data_instance[self.criteria_field])[0]
+                )
+                for task_data_instance in task_data
+            ]
+        else:
+            logger.info(
+                "Reading criteria from self. Criteria is a single CriteriaWithOptions, replicating it for all predictions"
+            )
+            criteria_list = [self.criteria] * eval_count
+        unique_criteria_names = list({criteria.name for criteria in criteria_list})
+
+        logger.info(f"Criteria names are '{', '.join(unique_criteria_names)}'")
+        return criteria_list
+
+    def get_predictions(
+        self,
+        task_data: List[Dict[str, Any]],
+        criteria: List[Criteria],
+        predictions: List[str],
+    ) -> List[str]:
+        if not predictions or all(
+            (
+                isinstance(prediction, EmptyPrediction)
+                or prediction == str(EmptyPrediction())
+            )
+            for prediction in predictions
+        ):
+            predictions_from_task_data = []
+            for i, td in enumerate(task_data):
+                if (
+                    criteria[i].prediction_field is not None
+                    and criteria[i].prediction_field in td
+                ):
+                    predictions_from_task_data.append(
+                        dict_get(td, criteria[i].prediction_field)  # type: ignore
+                    )
+                else:
+                    raise UnitxtError(
+                        "You must set either the predictions in the evaluate() call or specify the prediction field name to be taken from the task_data using the `Criteria`'s prediction_field field."
+                    )
+            return predictions_from_task_data
+
+        return predictions
+
+
+class EvalAssistLLMJudgeDirect(EvalAssistLLMJudge):
+    """EvalAssistLLMJudgeDirect is a specialized evaluation metric that performs Direct Assessment using an LLM to score responses based on a predefined evaluation criterion.
+
+    Direct Assessment is an evaluation paradigm in which the LLM selects one of a
+    predefined set of options based on an assessment criterion. This approach can
+    be used for Likert-scale scoring (e.g., 1-5) or selecting from semantically
+    conditioned literals (e.g., Yes/No, Pass/Fail).
+    """
+
+    criteria: CriteriaWithOptions = None
+    """The evaluation criteria, including a name, description, a predefined set of options, and an option_map."""
+    main_score = "llm_as_judge"
+    """The primary score name used in the results. By default, it will take the value of the criteria name (if only one criterion is being used for evaluation) or "llm_as_judge" otherwise."""
+    reduction_map = {"mean": ["llm_as_judge"]}
+    """A mapping used for score aggregation. By default, it will take the value of ``{"mean": ["llm_as_judge"]}``."""
+
+    def before_process_multi_stream(self):
+        """Ensures that the criteria is of type `CriteriaWithOptions`, raising an exception otherwise."""
+        super().before_process_multi_stream()
+        if self.criteria is not None and not isinstance(
+            self.criteria, CriteriaWithOptions
+        ):
+            raise Exception(
+                f"The type of the criteria must be 'CriteriaWithOptions', instead it is of type '{type(self.criteria)}'"
+            )
+        return
+
+    def __set_main_score(self, criterias: List[Criteria]):
+        unique_criteria_names = list({criteria.name for criteria in criterias})
+        if len(unique_criteria_names) == 1 and criterias[0].name != "":
+            self.main_score = "_".join(criterias[0].name.lower().split(" "))
+            self.reduction_map = {"mean": [self.main_score]}
+
+    def compute(
+        self,
+        references: List[List[str]],
+        predictions: List[str],
+        task_data: List[Dict[str, Any]],
+    ) -> List[Dict]:
+        logger.info(
+            f"Starting evaluation with judge {self.inference_engine.get_pretty_print_name()}"
+        )
+        from evalassist.judges import (
+            Criteria,
+            DirectInstance,
+            DirectInstanceResult,
+            DirectJudge,
+        )
+
+        judge = DirectJudge(self.inference_engine)
+
+        evaluations_count = len(task_data)
+        # TODO: find out how to serialize and deserialize enums
+        criteria = self.get_criteria(task_data, evaluations_count)
+        self.__set_main_score(criteria)
+        predictions = self.get_predictions(task_data, criteria, predictions)
+        contexts = self.get_contexts(task_data, criteria)
+        eval_assist_criteria = [
+            Criteria.from_unitxt_criteria(criterion) for criterion in criteria
+        ]
+
+        instances = [
+            DirectInstance(
+                context=context,
+                response=prediction,
+            )
+            for prediction, context in zip(predictions, contexts)
+        ]
+
+        results: list[DirectInstanceResult] = judge(
+            instances=instances, criteria=eval_assist_criteria
+        )
+        parsed_results: list[dict] = [
+            {
+                "selected_option": r.selected_option,
+                "explanation": r.explanation,
+                "feedback": r.feedback if r.feedback is not None else None,
+                "prompt": r.metadata["prompt"],
+                "positional_bias": r.positional_bias.detected
+                if r.positional_bias is not None
+                else None,
+                self.main_score: r.score if r.score is not None else 0.0,
+            }
+            for r in results
+        ]
+
+        parsed_results = [
+            {
+                f"{self.main_score}_{k}" if k != self.main_score else self.main_score: v
+                for k, v in r.items()
+            }
+            for r in parsed_results
+        ]
+
+        return self.clean_results(parsed_results)
diff --git a/src/unitxt/llm_as_judge_constants.py b/src/unitxt/llm_as_judge_constants.py
index 7097f14715..8688839ed4 100644
--- a/src/unitxt/llm_as_judge_constants.py
+++ b/src/unitxt/llm_as_judge_constants.py
@@ -1173,6 +1173,42 @@ class DirectCriteriaCatalogEnum(Enum):
         },
     )
 
+    REFERENCE_DOCUMENT_FAITHFULNESS2 = CriteriaWithOptions(
+        name="reference_document_faithfulness2",
+        description="""
+    Is the prediction grounded in the reference document?
+
+    To be grounded in the reference document, all the information of the prediction must either be present in the reference document or deducible from the reference document.
+    Base your answer only on the information in the reference document. If the prediction is correct but not present in the reference document, then it is not grounded.
+    """,
+        prediction_field="response",
+        context_fields=["reference_document"],
+        options=[
+            CriteriaOption(
+                name="Completely grounded",
+                description="The prediction is fully grounded in the reference document.",
+            ),
+            CriteriaOption(
+                name="Mostly grounded",
+                description="The vast majority of the information in the prediction is grounded in the reference document, but there is a small or negligible part of the prediction which is not present in the reference document.",
+            ),
+            CriteriaOption(
+                name="Somewhat grounded",
+                description="Some of the information in the prediction is grounded in the reference document.",
+            ),
+            CriteriaOption(
+                name="Not grounded",
+                description="Most or all of the information in the prediction is not grounded in the reference document.",
+            ),
+        ],
+        option_map={
+            "Completely grounded": 1.0,
+            "Mostly grounded": 0.75,
+            "Somewhat grounded": 0.25,
+            "Not grounded": 0.0,
+        },
+    )
+
     GENERAL_HARM_USER_MESSAGE = CriteriaWithOptions(
         name="user_message_general_harm",
         description="The User message is considered universally harmful.",
diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index d6f5de9609..fde53924dc 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -3240,7 +3240,7 @@ class F1Binary(GlobalMetric):
 
     process_single_instances = False
     main_score = "f1_binary"
-    average = None
+    average: str = None
     threshold = 0.5
     prediction_type = Union[float, int]
     _metric = None
@@ -3275,6 +3275,7 @@ def compute(
             y_pred=int_predictions,
             labels=[0, 1],
             average=self.average,
+            zero_division=0,
         )
         if self.average is None:
             return {
diff --git a/tests/library/test_hf_artifacts.py b/tests/library/test_hf_artifacts.py
index 9bd1912170..3940b300ba 100644
--- a/tests/library/test_hf_artifacts.py
+++ b/tests/library/test_hf_artifacts.py
@@ -12,7 +12,8 @@ class HFTests(UnitxtTestCase):
 
     def test_dataset_imports(self):
         missing_imports = get_missing_imports(
-            unitxt.dataset_file, exclude=["dataset", "evaluate_cli", "__init__", "api"]
+            unitxt.dataset_file,
+            exclude=["dataset", "evaluate_cli", "__init__", "api", "evalassist_judge"],
         )
         self.assertEqual(missing_imports, [])
 
@@ -27,6 +28,7 @@ def test_metric_imports(self):
                 "dataset_utils",
                 "evaluate_cli",
                 "api",
+                "evalassist_judge",
             ],
         )
         self.assertEqual(missing_imports, [])