Skip to content

Commit 80b2244

Browse files
committed
allow repeats
1 parent f198094 commit 80b2244

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

lm_eval/tasks/math500/math500.yaml

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ test_split: test
66
doc_to_text: "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n## Step 1: [Concise description]\n[Brief explanation and calculations]\n\n## Step 2: [Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem.\n\nProblem: {{ problem }}"
77
process_results: !function utils.process_results
88
doc_to_target: "{{answer if few_shot is undefined else solution}}"
9+
repeats: 2
910
generation_kwargs:
1011
until: []
1112
max_gen_toks: 5120
12-
do_sample: false
13-
temperature: 0
13+
do_sample: true
14+
temperature: 0.6
1415
metric_list:
1516
- metric: exact_match
1617
aggregation: mean

lm_eval/tasks/math500/utils.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def _process_doc(doc: dict) -> dict:
6363
# ]
6464

6565

66+
# calculate pass@1 for all results
6667
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
6768
candidates = results[0]
6869

@@ -184,18 +185,18 @@ def is_equiv(x1: str, x2: str) -> bool:
184185
return False
185186

186187

187-
def get_unnormalized_answer(text: str) -> str:
188-
INVALID_ANSWER = "[invalidanswer]"
189-
end_seq = "I hope it is correct."
190-
text += end_seq
191-
match = re.search(
192-
r"Final Answer: The final answer is(.*?). I hope it is correct.",
193-
text,
194-
)
195-
if match:
196-
return match.group(1).strip()
197-
else:
198-
return INVALID_ANSWER
188+
# def get_unnormalized_answer(text: str) -> str:
189+
# INVALID_ANSWER = "[invalidanswer]"
190+
# end_seq = "I hope it is correct."
191+
# text += end_seq
192+
# match = re.search(
193+
# r"Final Answer: The final answer is(.*?). I hope it is correct.",
194+
# text,
195+
# )
196+
# if match:
197+
# return match.group(1).strip()
198+
# else:
199+
# return INVALID_ANSWER
199200

200201

201202
SUBSTITUTIONS = [

0 commit comments

Comments
 (0)