Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 93 additions & 6 deletions lm_eval/tasks/mlmm/multilingual_mmlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,93 @@
","
)

# Subject names used to build the per-subject tasks: create_all_tasks pairs
# every entry with every language as "mlmm_mmlu_{lang}_{subject}", and
# _filter_doc compares an entry against the first "/"-separated segment of
# a document's "id" field.
SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]


def create_all_tasks():
    """Create the dictionary of multilingual MMLU tasks.

    Two kinds of entries are registered:

    * one task per language, named ``mlmm_mmlu_{lang}``, covering all subjects;
    * one task per (language, subject) pair, named
      ``mlmm_mmlu_{lang}_{subject}``, restricted to that subject.

    :return: {task_name: task}
        e.g. {mlmm_mmlu_<lang>: Task, mlmm_mmlu_<lang>_abstract_algebra: Task}
    """
    tasks = {}

    # Language tasks
    tasks.update({f"mlmm_mmlu_{lang}": create_task(lang) for lang in LANGS})

    # Subject tasks
    tasks.update(
        {
            f"mlmm_mmlu_{lang}_{subject}": create_task(lang, subject)
            for lang in LANGS
            for subject in SUBJECTS
        }
    )

    return tasks


def create_task(lang, subject=None):
    """Return a task class bound to a language and, optionally, a subject.

    The returned class subclasses ``GeneralHendrycksTest`` and takes no
    constructor arguments; ``lang`` and ``subject`` are captured from this
    call and forwarded to the base initializer.

    :param lang: language identifier forwarded to ``GeneralHendrycksTest``
    :param subject: optional subject name; ``None`` (the default) keeps the
        previous single-argument behaviour of covering every subject
    :return: the ``HendrycksTest`` class (not an instance)
    """

    class HendrycksTest(GeneralHendrycksTest):
        def __init__(self):
            super().__init__(lang, subject)

    return HendrycksTest

Expand All @@ -49,9 +123,10 @@ class GeneralHendrycksTest(MultipleChoiceTask):
NUM_FEW_SHOT = 25
DATASET_NAME = None

def __init__(self, lang, subject=None):
    """Configure the task for one language and, optionally, one subject.

    :param lang: language identifier; DATASET_NAME becomes ``mmlu_{lang}``
    :param subject: optional subject name stored on the instance and read
        by ``_filter_doc``; ``None`` (the default) applies no filtering
    """
    self.DATASET_NAME = f"mmlu_{lang}"
    # NOTE(review): a reviewer asked whether this should be
    # self.DATASET_PATH = "malteos/m_mmlu" instead of a path resolved through
    # get_mlmm_dataset_path -- unresolved in this view, confirm before changing.
    self.DATASET_PATH = get_mlmm_dataset_path("datasets/m_mmlu")
    self.subject = subject

    super().__init__()

Expand All @@ -65,10 +140,21 @@ def has_test_docs(self):
return True

def validation_docs(self):
    """Return the processed validation documents, lazily.

    Documents from ``self.dataset["validation"]`` are first passed through
    ``self._filter_doc`` and the survivors through ``self._process_doc``.

    :return: a ``map`` iterator over the processed documents
    """
    return map(
        self._process_doc, filter(self._filter_doc, self.dataset["validation"])
    )

def test_docs(self):
    """Return the processed test documents, lazily.

    Documents from ``self.dataset["test"]`` are first passed through
    ``self._filter_doc`` and the survivors through ``self._process_doc``.

    :return: a ``map`` iterator over the processed documents
    """
    return map(self._process_doc, filter(self._filter_doc, self.dataset["test"]))

def _filter_doc(self, doc):
if self.subject:
# Filter based on subject
subject = doc["id"].split("/", 2)[0]

return subject == self.subject

return True

def _process_doc(self, doc):
def format_example(doc, keys):
Expand All @@ -89,6 +175,7 @@ def format_example(doc, keys):
return prompt

keys = ["A", "B", "C", "D"]

return {
"query": format_example(doc, keys),
"choices": doc["choices"],
Expand Down