11 changes: 11 additions & 0 deletions lm_eval/tasks/__init__.py
@@ -64,6 +64,12 @@
from .aam.all_tasks_registry import TASK_REGISTRY as AAM_TASK_REGISTRY
from .opengptx.all_tasks_registry import TASK_REGISTRY as OGX_TASK_REGISTRY

from .mlmm import multilingual_arc
from .mlmm import multilingual_hellaswag
from .mlmm import multilingual_mmlu
from .mlmm import multilingual_truthfulqa


########################################
# Translation tasks
########################################
@@ -328,6 +334,11 @@
**tmp_new_pawsx.construct_tasks(),
**tmp_new_xnli.construct_tasks(),
**mgsm.construct_tasks(),
# Multilingual OpenLLM Evaluation
**multilingual_arc.create_all_tasks(),
**multilingual_mmlu.create_all_tasks(),
**multilingual_truthfulqa.create_all_tasks(),
**multilingual_hellaswag.create_all_tasks(),
}

# append the luminous (eg. Aleph-Alpha implemented) tasks to the whole registry
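For orientation, a minimal sketch (not part of this diff) of how the newly registered task names can be resolved; it assumes `lm_eval.tasks.get_task_dict` behaves here as it does for the existing registry entries:

```python
# Minimal sketch, not part of the diff: resolving two of the newly registered
# multilingual task names. Assumes lm_eval.tasks.get_task_dict works as it
# does for the other TASK_REGISTRY entries.
from lm_eval import tasks

task_dict = tasks.get_task_dict(["mlmm_arc_de", "mlmm_hellaswag_vi"])
for name, task in task_dict.items():
    print(name, task.DATASET_NAME)  # e.g. mlmm_arc_de arc_de
```

Note that instantiating these tasks requires the mlmm datasets to be available locally; `get_mlmm_dataset_path` raises `FileNotFoundError` otherwise (see the next file).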
29 changes: 29 additions & 0 deletions lm_eval/tasks/mlmm/__init__.py
@@ -0,0 +1,29 @@
"""
Tasks from "Multilingual Large Language Models Evaluation Benchmark"

Source: https://github.com/nlp-uoregon/mlmm-evaluation

This repo contains benchmark datasets and evaluation scripts for Multilingual Large Language Models (LLMs). These datasets can be used to evaluate the models across 26 different languages and encompass three distinct tasks: ARC, HellaSwag, and MMLU. This is released as a part of our [Okapi framework](https://github.com/nlp-uoregon/Okapi) for multilingual instruction-tuned LLMs with reinforcement learning from human feedback.

- [**ARC**](https://allenai.org/data/arc): A dataset with 7,787 genuine grade-school level, multiple-choice science questions, assembled to encourage research in advanced question-answering.
- [**HellaSwag**](https://allenai.org/data/hellaswag): HellaSwag is a dataset for studying grounded commonsense inference. It consists of 70k multiple-choice questions about grounded situations: each question comes from one of two domains, *activitynet* or *wikihow*, with four answer choices about what might happen next in the scene. The correct answer is the (real) sentence for the next event; the three incorrect answers are adversarially generated and human-verified so as to fool machines but not humans.
- [**MMLU**](https://arxiv.org/pdf/2009.03300.pdf): This dataset contains multiple-choice questions derived from diverse fields of knowledge. The test covers subjects in the humanities, social sciences, hard sciences, and other areas that are important for some people to learn.

Currently, our datasets support 26 languages: Russian, German, Chinese, French, Spanish, Italian, Dutch, Vietnamese, Indonesian, Arabic, Hungarian, Romanian, Danish, Slovak, Ukrainian, Catalan, Serbian, Croatian, Hindi, Bengali, Tamil, Nepali, Malayalam, Marathi, Telugu, and Kannada.

"""
import os


def get_mlmm_dataset_path(dataset_path: str) -> str:
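    """Resolve dataset_path, prefixing it with MLMM_DATASET_BASE_PATH if that environment variable is set, and verify the resulting path exists."""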
base_path = os.environ.get("MLMM_DATASET_BASE_PATH", None)

if base_path:
dataset_path = os.path.join(base_path, dataset_path)

if not os.path.exists(dataset_path):
raise FileNotFoundError(
f"Dataset path does not exist ({dataset_path}). If you already downloaded the data, try to set the MLMM_DATASET_BASE_PATH environment variable. To download the data, follow the instruction as provided here: https://github.com/nlp-uoregon/mlmm-evaluation/tree/main#basic-usage"
)

return dataset_path
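A brief usage sketch (not part of the diff); the base directory used below is hypothetical and stands in for wherever the mlmm-evaluation data was downloaded per the linked README:

```python
# Usage sketch: the environment variable must be set before the task classes
# resolve their DATASET_PATH. "/data/mlmm-evaluation" is a made-up location.
import os

os.environ["MLMM_DATASET_BASE_PATH"] = "/data/mlmm-evaluation"

from lm_eval.tasks.mlmm import get_mlmm_dataset_path

print(get_mlmm_dataset_path("datasets/m_arc"))
# -> /data/mlmm-evaluation/datasets/m_arc (raises FileNotFoundError if the directory is missing)
```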
95 changes: 95 additions & 0 deletions lm_eval/tasks/mlmm/multilingual_arc.py
@@ -0,0 +1,95 @@
"""
Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge
https://arxiv.org/pdf/1803.05457.pdf

The ARC dataset consists of 7,787 science exam questions drawn from a variety
of sources, including science questions provided under license by a research
partner affiliated with AI2. These are text-only, English language exam questions
that span several grade levels as indicated in the files. Each question has a
multiple choice structure (typically 4 answer options). The questions are sorted
into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and
a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.

Homepage: https://allenai.org/data/arc
"""
from lm_eval.base import MultipleChoiceTask
from . import get_mlmm_dataset_path

_CITATION = """
@article{Clark2018ThinkYH,
title={Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge},
author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord},
journal={ArXiv},
year={2018},
volume={abs/1803.05457}
}
"""

LANGS = "ar,bn,ca,da,de,es,eu,fr,gu,hi,hr,hu,hy,id,it,kn,ml,mr,ne,nl,pt,ro,ru,sk,sr,sv,ta,te,uk,vi,zh".split(
","
)


def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
e.g. {arc_vi: Task, arc_bn: Task}
"""
return {f"mlmm_arc_{lang}": create_task(lang) for lang in LANGS}


def create_task(lang):
class ATest(MultilingualARC):
def __init__(self):
super().__init__(lang)

return ATest


class MultilingualARC(MultipleChoiceTask):
def __init__(self, lang, **kwargs):
self.VERSION = 0
self.lang = lang
self.DATASET_NAME = f"arc_{lang}"
self.DATASET_PATH = get_mlmm_dataset_path("datasets/m_arc")
self.NUM_FEW_SHOT = 25
super().__init__(**kwargs)

def has_training_docs(self):
return True

def has_validation_docs(self):
return True

def has_test_docs(self):
return True

def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs

def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])

def test_docs(self):
return map(self._process_doc, self.dataset["test"])

def _process_doc(self, doc):
        # NOTE: answerKey is expected to be a letter; it is mapped to a 0-based gold index.
out_doc = {
"id": doc["id"],
"query": "Question: " + doc["question"] + "\nAnswer:",
"choices": doc["choices"],
"gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]),
}
return out_doc

def doc_to_text(self, doc):
return doc["query"]

def should_decontaminate(self):
return True

def doc_to_decontamination_query(self, doc):
return doc["query"]
101 changes: 101 additions & 0 deletions lm_eval/tasks/mlmm/multilingual_hellaswag.py
@@ -0,0 +1,101 @@
"""
HellaSwag: Can a Machine Really Finish Your Sentence?
https://arxiv.org/pdf/1905.07830.pdf
Hellaswag is a commonsense inference challenge dataset. Though its questions are
trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is
achieved via Adversarial Filtering (AF), a data collection paradigm wherein a
series of discriminators iteratively select an adversarial set of machine-generated
wrong answers. AF proves to be surprisingly robust. The key insight is to scale up
the length and complexity of the dataset examples towards a critical 'Goldilocks'
zone wherein generated text is ridiculous to humans, yet often misclassified by
state-of-the-art models.
Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask
from . import get_mlmm_dataset_path

_CITATION = """
@inproceedings{zellers2019hellaswag,
title={HellaSwag: Can a Machine Really Finish Your Sentence?},
author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
year={2019}
}
"""

LANGS = "ar,bn,ca,da,de,es,eu,fr,gu,hi,hr,hu,hy,id,it,kn,ml,mr,ne,nl,pt,ro,ru,sk,sr,sv,ta,te,uk,vi,zh".split(
","
)


def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
e.g. {hellaswag_vi: Task, hellaswag_en: Task}
"""
return {f"mlmm_hellaswag_{lang}": create_task(lang) for lang in LANGS}


def create_task(lang):
class ATest(HellaSwag):
def __init__(self):
super().__init__(lang)

return ATest


class HellaSwag(MultipleChoiceTask):
def __init__(self, lang, **kwargs):
self.VERSION = 1
self.lang = lang
self.DATASET_NAME = f"hellaswag_{lang}"
self.DATASET_PATH = get_mlmm_dataset_path("datasets/m_hellaswag")
self.NUM_FEW_SHOT = 0
super().__init__(**kwargs)

def has_training_docs(self):
return False

def has_validation_docs(self):
return True

def has_test_docs(self):
return False

def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs

def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])

def _process_doc(self, doc):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
out_doc = {
"query": self.preprocess(doc["activity_label"] + ": " + ctx),
"choices": [self.preprocess(ending) for ending in doc["endings"]],
"gold": int(doc["label"]),
}
return out_doc

@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub("\\[.*?\\]", "", text)
        text = text.replace("  ", " ")  # collapse double spaces left by the substitutions above
return text

def doc_to_text(self, doc):
return doc["query"]

def should_decontaminate(self):
return True

def doc_to_decontamination_query(self, doc):
return doc["query"]
116 changes: 116 additions & 0 deletions lm_eval/tasks/mlmm/multilingual_mmlu.py
@@ -0,0 +1,116 @@
"""
Measuring Massive Multitask Language Understanding
https://arxiv.org/pdf/2009.03300.pdf
The Hendrycks Test is a benchmark that measures a text model's multitask accuracy.
The test covers 57 tasks including elementary mathematics, US history, computer
science, law, and more. To attain high accuracy on this test, models must possess
extensive world knowledge and problem-solving ability. By comprehensively evaluating
the breadth and depth of a model's academic and professional understanding, the
Hendrycks Test can be used to analyze models across many tasks and to identify
important shortcomings.
Homepage: https://github.com/hendrycks/test
"""
from lm_eval.base import MultipleChoiceTask
from . import get_mlmm_dataset_path

_CITATION = """
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
year={2021}
}
"""
LANGS = "ar,bn,ca,da,de,es,eu,fr,gu,hi,hr,hu,hy,id,it,kn,ml,mr,ne,nl,pt,ro,ru,sk,sr,sv,ta,te,uk,vi,zh".split(
","
)


def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task}
"""
return {f"mlmm_mmlu_{lang}": create_task(lang) for lang in LANGS}


def create_task(lang):
class HendrycksTest(GeneralHendrycksTest):
def __init__(self):
super().__init__(lang)

return HendrycksTest


class GeneralHendrycksTest(MultipleChoiceTask):
VERSION = 0
NUM_FEW_SHOT = 25
DATASET_NAME = None

def __init__(self, lang):

Review comment:
The MMLU dataset is subdivided into various categories (cf. https://github.com/OpenGPTX/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_test.py); having that option here would be nice for comparison with the English task. However, the dataset as loaded here is not split by subject, so the dataset builder script would have to be modified.

Author reply:
Subject-level tasks are added in the latest commit. To achieve that, I uploaded the datasets to HF and updated the builder script to include the subject.

self.DATASET_NAME = f"mmlu_{lang}"
self.DATASET_PATH = get_mlmm_dataset_path("datasets/m_mmlu")

Review comment:
Shouldn't this be self.DATASET_PATH = "malteos/m_mmlu" instead?


super().__init__()

def has_training_docs(self):
return False

def has_validation_docs(self):
return True

def has_test_docs(self):
return True

def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])

Review comment:
Strangely, the dataset builder loads {lang}_dev.json as the validation split instead of {lang}_val.json.

Author reply:
It could easily be changed, but I would keep it like this for consistency.


def test_docs(self):
return map(self._process_doc, self.dataset["test"])

def _process_doc(self, doc):
def format_example(doc, keys):
"""
Question: <prompt>
Choices:
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
Answer:
"""
prompt = "Question: " + doc["question"] + "\nChoices:\n"
prompt += "".join(
[f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])]
)
prompt += "Answer:"
return prompt

keys = ["A", "B", "C", "D"]
return {
"query": format_example(doc, keys),
"choices": doc["choices"],
"gold": keys.index(doc["answer"])
if isinstance(doc["answer"], str)
else doc["answer"],
}

def fewshot_examples(self, k, rnd):
        # Few-shot examples are drawn from the dev split rather than from training
        # docs: this task exposes no training split, and dev matches the val/test distribution.

if self._fewshot_docs is None:
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))

return rnd.sample(list(self._fewshot_docs), k)

def doc_to_text(self, doc):
return doc["query"]

def should_decontaminate(self):
return True

def doc_to_decontamination_query(self, doc):
return doc["query"]