@@ -49,7 +49,7 @@ def load_and_filter_dataset(
4949 return dataset .iloc [start :end ] if end > 0 else dataset .iloc [start :]
5050
5151
52- class MMLUProTaskHandler (MMLUTaskHandler ):
52+ class MMLUProTaskHandler (TaskHandler ):
5353 def __init__ (self , task_config : TaskConfig ):
5454 super ().__init__ (task_config )
5555 self .choices = [
@@ -71,9 +71,27 @@ def __init__(self, task_config: TaskConfig):
7171 "P" ,
7272 ]
7373
74- def generate_prompt (self , prompt ):
74+ def generate_prompt (self , problem ):
75+ multiple_choice_string = self .get_multiple_choice_answers (problem )
76+ prompt = problem ["question" ] + "\n " + multiple_choice_string
7577 return self .task_config .templating_parameters ["template" ].format (prompt = prompt )
7678
79+ def update_results (self , problem , response ):
80+ # Initialize the response structure
81+ response_entry = {
82+ "content" : response ,
83+ "correctness" : None ,
84+ "reason" : None ,
85+ }
86+ curr_res = self .check_correctness (problem , generation = response )
87+ if curr_res :
88+ response_entry ["correctness" ] = True
89+ response_entry ["reason" ] = ""
90+ else :
91+ response_entry ["correctness" ] = False
92+ response_entry ["reason" ] = "Solution is incorrect."
93+ return response_entry
94+
7795 def check_correctness (self , problem , generation ):
7896 pred = mmlu_pro_extract_answer (generation )
7997 answer = self .choices [problem ["answer_index" ]]
0 commit comments