Skip to content

Commit fd67715

Browse files
committed
Fix MMLU-Pro handler: subclassing MMLUTaskHandler was a mistake; derive from TaskHandler instead
Signed-off-by: SumanthRH <[email protected]>
1 parent: 5fb4f2b · commit: fd67715

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

skythought/evals/tasks/mmlu/mmlu_handler.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def load_and_filter_dataset(
49 49
return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
50 50

51 51

52-
class MMLUProTaskHandler(MMLUTaskHandler):
52+
class MMLUProTaskHandler(TaskHandler):
53 53
def __init__(self, task_config: TaskConfig):
54 54
super().__init__(task_config)
55 55
self.choices = [
@@ -71,9 +71,27 @@ def __init__(self, task_config: TaskConfig):
71 71
"P",
72 72
]
73 73

74-
def generate_prompt(self, prompt):
74+
def generate_prompt(self, problem):
75+
multiple_choice_string = self.get_multiple_choice_answers(problem)
76+
prompt = problem["question"] + "\n" + multiple_choice_string
75 77
return self.task_config.templating_parameters["template"].format(prompt=prompt)
76 78

79+
def update_results(self, problem, response):
80+
# Initialize the response structure
81+
response_entry = {
82+
"content": response,
83+
"correctness": None,
84+
"reason": None,
85+
}
86+
curr_res = self.check_correctness(problem, generation=response)
87+
if curr_res:
88+
response_entry["correctness"] = True
89+
response_entry["reason"] = ""
90+
else:
91+
response_entry["correctness"] = False
92+
response_entry["reason"] = "Solution is incorrect."
93+
return response_entry
94+
77 95
def check_correctness(self, problem, generation):
78 96
pred = mmlu_pro_extract_answer(generation)
79 97
answer = self.choices[problem["answer_index"]]

0 commit comments

Comments
 (0)