From 6bc8b8cc54360e97f8de0a6e9bdd681c5055fbbe Mon Sep 17 00:00:00 2001 From: Baber Date: Tue, 21 Jan 2025 23:37:50 +0000 Subject: [PATCH] add math500 pass@1 --- lm_eval/tasks/math500/math500.yaml | 16 +++++++++-- lm_eval/tasks/math500/utils.py | 46 +++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/lm_eval/tasks/math500/math500.yaml b/lm_eval/tasks/math500/math500.yaml index 10e636ba4c..68152777ca 100644 --- a/lm_eval/tasks/math500/math500.yaml +++ b/lm_eval/tasks/math500/math500.yaml @@ -4,16 +4,25 @@ process_docs: !function utils.process_docs output_type: generate_until test_split: test doc_to_text: "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n## Step 1: [Concise description]\n[Brief explanation and calculations]\n\n## Step 2: [Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\\\boxed{answer}$. 
# Pass@1 pipeline for the math500 task.
#
# The task YAML chains two custom filters:
#   1. filter_final_answer -- reduces every sampled generation to its
#      normalized \boxed{} answer,
#   2. get_metric          -- scores each doc 1 if any sample matches the
#      reference answer, else 0.
#
# Relies on helpers defined elsewhere in this module:
# last_boxed_only_string, remove_boxed, normalize_final_answer, is_equiv.


def filter_final_answer(resps: list[list[str]], docs) -> list[list[str]]:
    """Extract the normalized final answer from each sampled generation.

    Args:
        resps: one list of sampled generations (strings) per document.
        docs: the matching documents (unused; required by the custom-filter
            interface).

    Returns:
        The same nested structure with each generation replaced by the
        normalized content of its last ``\\boxed{}`` expression.
    """
    filtered = []
    for doc_resps in resps:
        # NOTE(review): per the list[list[str]] annotation each element is a
        # full generation string, so pass `r` itself — the original `r[0]`
        # extracted the answer from the first *character* of the generation.
        # NOTE(review): remove_boxed raises if a generation contains no
        # \boxed{} at all — presumably acceptable for this prompt; confirm.
        filtered.append(
            [
                normalize_final_answer(remove_boxed(last_boxed_only_string(r)))
                for r in doc_resps
            ]
        )
    return filtered


def process_results(docs: dict, resps: list[dict]) -> dict:
    """Forward the per-doc metric dict already computed by get_metric."""
    return resps[0]


# calculate pass@1 for all results
def get_metric(predictions: list[list[str]], references: list[dict]) -> List[Dict[str, int]]:
    """Score each document 1 if any sampled answer matches the reference.

    Args:
        predictions: per-document candidate answers, already normalized by
            filter_final_answer (chained before this filter in the YAML).
        references: per-document reference dicts carrying an "answer" key.

    Returns:
        One ``{"acc": 0 or 1}`` dict per document. The key matches the
        task YAML's ``metric: acc`` entry (the original returned
        "accuracy", which the config never aggregates).
    """
    res = []
    for reference, candidates in zip(references, predictions):
        hit = 0
        for candidate in candidates:
            # Candidates were already reduced to bare answers upstream, so
            # only re-extract when a \boxed{} wrapper is still present;
            # the original unconditionally re-extracted, making
            # remove_boxed(None) crash on pre-filtered input.
            boxed = last_boxed_only_string(candidate)
            answer = (
                normalize_final_answer(remove_boxed(boxed))
                if boxed is not None
                else candidate
            )
            if is_equiv(answer, reference["answer"]):
                hit = 1
                break  # pass@1: one matching sample is enough
            # Original bug: append + break ran unconditionally here, so only
            # the first candidate was ever checked and `retval` was unbound
            # (NameError) whenever that candidate did not match.
        res.append({"acc": hit})  # empty candidate list naturally scores 0
    return res