Skip to content

Commit

Permalink
add math500 pass@1
Browse files Browse the repository at this point in the history
  • Loading branch information
baberabb committed Jan 21, 2025
1 parent 80b2244 commit 6bc8b8c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
16 changes: 13 additions & 3 deletions lm_eval/tasks/math500/math500.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,30 @@ process_docs: !function utils.process_docs
output_type: generate_until
test_split: test
doc_to_text: "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n## Step 1: [Concise description]\n[Brief explanation and calculations]\n\n## Step 2: [Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem.\n\nProblem: {{ problem }}"
process_results: !function utils.process_results
#process_results: !function utils.process_results
doc_to_target: "{{answer if few_shot is undefined else solution}}"
process_results: !function utils.process_results
repeats: 2
generation_kwargs:
until: []
max_gen_toks: 5120
max_gen_toks: 1024
do_sample: true
top_p: 0.95
temperature: 0.6
filter_list:
- name: "pass@1"
filter:
- function: "custom"
filter_fn: !function utils.filter_final_answer
- function: "custom"
filter_fn: !function utils.get_metric
metric_list:
- metric: exact_match
- metric: acc
aggregation: mean
higher_is_better: true
num_fewshot: 0
metadata:
version: 1.0
dataset_kwargs:
trust_remote_code: true

46 changes: 32 additions & 14 deletions lm_eval/tasks/math500/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,39 @@ def _process_doc(doc: dict) -> dict:
# ]


# calculate pass@1 for all results
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
candidates = results[0]

answer = normalize_final_answer(remove_boxed(last_boxed_only_string(candidates)))

if is_equiv(answer, doc["answer"]):
retval = 1
else:
retval = 0
def filter_final_answer(resps: list[list[str]], docs) -> list[list[str]]:
answer = []
for resp in resps:
answer.append(
[
normalize_final_answer(remove_boxed(last_boxed_only_string(r[0])))
for r in resp
]
)
return answer

def process_results(docs: dict, resps: list[dict]) -> dict:
return resps[0]

results = {
"exact_match": retval,
}
return results
# calculate pass@1 for all results
def get_metric(predictions: list[list[str]], references: list[dict]) -> Dict[str, int]:
res = []
for reference, candidates in zip(references, predictions):
for candidate in candidates:
answer = normalize_final_answer(
remove_boxed(last_boxed_only_string(candidate))
)
if is_equiv(answer, reference["answer"]):
retval = 1

results = {
"accuracy": retval,
}
res.append(results)
break
else:
res.append({"accuracy": 0})
return res


def last_boxed_only_string(string: str) -> Optional[str]:
Expand Down

0 comments on commit 6bc8b8c

Please sign in to comment.