Commit b5ed97f

New Features (#14)

- RecallEM metric.
- Aggregation steps: filtering, column selection, tagging, value overwrite.
- Local inference step using vLLM; can generate synthetic datasets.
- Minor modifications to the QA system instructions.
- Ruff configuration file.
- Evaluation split in the training script.

1 parent f21cd32 commit b5ed97f

17 files changed, +272 −20 lines changed


.gitignore

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,7 @@
 /.python-version
 /outputs/
 __pycache__/
-/site/
+/site/
+/multirun/
+wandb
+.ipynb_checkpoints

configs/processing-nq.yaml

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+name: nq
+cache: false
+output_path: .
+steps:
+  - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader
+    inputs: train
+    dataset_config:
+      path: Tevatron/wikipedia-nq
+      split: train
+
+  - _target_: ragfit.processing.global_steps.sampling.ShuffleSelect
+    inputs: train
+    shuffle: 42
+    limit: 10000
+
+  - _target_: ragfit.processing.local_steps.prompter.TextPrompter
+    inputs: train
+    prompt_file: ragfit/processing/prompts/qa-short.txt
+    output_key: prompt
+    mapping:
+      query: query
+
+  - _target_: ragfit.processing.local_steps.inference.HFStep
+    inputs: train
+    input_key: prompt
+    output_key: generated
+    model_kwargs:
+      model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
+      instruction: ragfit/processing/prompts/prompt_instructions/qa-short.txt
+      num_gpus: 2
+      llm_params:
+        dtype: auto
+        max_model_len: 4096
+      generation:
+        temperature: 0
+        max_tokens: 50
+
+  - _target_: ragfit.processing.global_steps.output.OutputData
+    inputs: train
+    prefix: nq-with-answers

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+::: ragfit.processing.global_steps.filters

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+::: ragfit.processing.local_steps.inference

evaluation.py

Lines changed: 7 additions & 4 deletions
@@ -1,6 +1,7 @@
 import logging
 import os
 from collections import defaultdict
+from pathlib import Path
 
 import hydra
 import torch
@@ -93,10 +94,12 @@ def map_load(example, idx):
     if args.use_wandb:
         run.log(results, step=0)
 
-    if args.results_file:
-        with open(args.results_file, "w") as f:
-            yaml.dump(results, f, sort_keys=True)
-        logging.info(f"Results saved to {args.results_file}")
+    if args.results_file is None:
+        args.results_file = Path(args.generated_file).stem + "-results.yaml"
+
+    with open(args.results_file, "w") as f:
+        yaml.dump(results, f, sort_keys=True)
+    logging.info(f"Results saved to {args.results_file}")
 
 
 if __name__ == "__main__":
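
With this change, results are always written: if no results_file is given, the YAML is named after the generated-predictions file. A minimal sketch of the new default (the file name below is hypothetical):

from pathlib import Path

# hypothetical generated-predictions file; the default results path becomes "<stem>-results.yaml"
generated_file = "nq-test-generated.jsonl"
results_file = Path(generated_file).stem + "-results.yaml"
print(results_file)  # nq-test-generated-results.yaml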

pyproject.toml

Lines changed: 0 additions & 6 deletions
@@ -40,9 +40,3 @@ haystack = [
     "qdrant-haystack>=5.0.0",
 ]
 
-[tool.ruff]
-line-length = 90
-
-[tool.ruff.lint]
-select = ["E", "F", "W", "I", "N", "Q"]
-ignore = ["E203", "F841", "E501"]

ragfit/evaluation/metrics.py

Lines changed: 78 additions & 3 deletions
@@ -1,7 +1,10 @@
 import re
 import string
+import unicodedata
 from collections import Counter, defaultdict
 
+import regex
+
 from .base import MetricBase
 
 
@@ -71,16 +74,21 @@ def __init__(
         self.precision_recall_fn = precision_recall_fscore_support
         self.accuracy_fn = accuracy_score
 
+    def in_text(self, text):
+        if "yes" in text:
+            return 1
+        if "no" in text:
+            return 0
+        return 2
+
     def measure(self, example: dict):
         inputs = example[self.field]
         targets = example[self.target]
 
         if isinstance(targets[0], list):
             targets = [t[0] for t in targets]
 
-        inputs = [
-            self.mapping.get(normalize_text(i).strip(), self.else_value) for i in inputs
-        ]
+        inputs = [self.in_text(normalize_text(i).strip()) for i in inputs]
 
         targets = [
             self.mapping.get(normalize_text(t).strip(), self.else_value) for t in targets
@@ -222,6 +230,73 @@ def measure(self, example: dict):
         return {"StringEM": sum(scores) / len(scores)}
 
 
+class SimpleTokenizer(object):
+    ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+"
+    NON_WS = r"[^\p{Z}\p{C}]"
+
+    def __init__(self):
+        """
+        Args:
+            annotators: None or empty set (only tokenizes).
+        """
+        self._regexp = regex.compile(
+            "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS),
+            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
+        )
+
+    def tokenize(self, text, uncased=False):
+        matches = [m for m in self._regexp.finditer(text)]
+        if uncased:
+            tokens = [m.group().lower() for m in matches]
+        else:
+            tokens = [m.group() for m in matches]
+        return tokens
+
+
+class RecallEM(MetricBase):
+    """
+    Implementing EM as in XRAG.
+    """
+
+    def __init__(self, key_names, **kwargs) -> None:
+        """Initialize the Metrics class.
+
+        Args:
+            key_names (dict): A dictionary containing the field names.
+        """
+        super().__init__(key_names, **kwargs)
+        self.local = True
+
+    @staticmethod
+    def _normalize(text):
+        return unicodedata.normalize("NFD", text)
+
+    def has_answer(self, answers, text, tokenizer=SimpleTokenizer()):
+        """Check if a document contains an answer string."""
+        text = self._normalize(text)
+        text = tokenizer.tokenize(text, uncased=True)
+
+        for answer in answers:
+            answer = self._normalize(answer)
+            answer = tokenizer.tokenize(answer, uncased=True)
+            for i in range(0, len(text) - len(answer) + 1):
+                if answer == text[i : i + len(answer)]:
+                    return True
+        return False
+
+    def measure(self, example: dict):
+        input = example[self.field]
+        target = example[self.target]
+
+        assert isinstance(input, str), f"Generated text should be a string: {input}"
+
+        if not isinstance(target, list):
+            target = [target]
+
+        scores = self.has_answer(target, input)
+        return {"recallEM": int(scores)}
+
+
 class BERTScore(MetricBase):
     """
     BERTScore metric, based on the BERTScore library.
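
RecallEM returns 1 when any reference answer appears in the generated text as a contiguous, case-insensitive token span after Unicode (NFD) normalization. The standalone sketch below mirrors the SimpleTokenizer/has_answer logic above purely for illustration; it is not part of the commit and needs the regex package.

import unicodedata

import regex

# same token pattern as SimpleTokenizer: runs of letters/digits/marks, or any single non-space character
_TOKEN_RE = regex.compile(
    r"([\p{L}\p{N}\p{M}]+)|([^\p{Z}\p{C}])",
    flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
)

def _tokens(text):
    # NFD-normalize and lowercase, as RecallEM does before matching
    return [m.group().lower() for m in _TOKEN_RE.finditer(unicodedata.normalize("NFD", text))]

def has_answer(answers, text):
    text_toks = _tokens(text)
    for answer in answers:
        ans_toks = _tokens(answer)
        for i in range(len(text_toks) - len(ans_toks) + 1):
            if ans_toks == text_toks[i : i + len(ans_toks)]:
                return True
    return False

print(has_answer(["Lisbon"], "The capital of Portugal is Lisbon."))  # True -> recallEM = 1
print(has_answer(["Lisbon"], "Portugal's capital city."))            # False -> recallEM = 0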

ragfit/processing/global_steps/aggregation.py

Lines changed: 60 additions & 0 deletions
@@ -1,6 +1,42 @@
 from datasets import concatenate_datasets
 
 from ..step import GlobalStep
+from .filters import filters
+
+
+class FilterDataset(GlobalStep):
+    """
+    Step for filtering a dataset.
+    """
+
+    def __init__(self, filter_fn, **kwargs):
+        """
+        Args:
+            filter_fn (function): Function to filter the dataset.
+        """
+        super().__init__(**kwargs)
+        self.filter_fn = filters[filter_fn]
+
+    def process(self, dataset_name, datasets, **kwargs):
+        datasets[dataset_name] = datasets[dataset_name].filter(self.filter_fn)
+
+
+class SelectColumns(GlobalStep):
+    """
+    Step for selecting specified columns in a dataset.
+    """
+
+    def __init__(self, columns: list[str], **kwargs):
+        """
+        Args:
+            columns (list): List of keys to keep in the dataset.
+        """
+        super().__init__(**kwargs)
+        assert isinstance(columns, list), "columns should be a list of strings."
+        self.columns = columns
+
+    def process(self, dataset_name, datasets, **kwargs):
+        datasets[dataset_name] = datasets[dataset_name].select_columns(self.columns)
 
 
 class MergeDatasets(GlobalStep):
@@ -29,3 +65,27 @@ def process(self, dataset_name, datasets, **kwargs):
         data = data.shuffle(self.shuffle)
         datasets[self.output] = data
         self.completed = True
+
+
+class DatasetTagger(GlobalStep):
+    """
+    Class to tag each example with the dataset name. Useful when running aggregations.
+    """
+
+    def __init__(self, keyword="source", **kwargs):
+        """
+        Args:
+            keyword (str): The key to use for tagging. Default is "source".
+        """
+        super().__init__(**kwargs)
+        self.keyword = keyword
+
+    def tag(self, item, dataset_name):
+        item[self.keyword] = dataset_name
+        return item
+
+    def process(self, dataset_name, datasets, **kwargs):
+        datasets[dataset_name] = datasets[dataset_name].map(
+            lambda item: self.tag(item, dataset_name),
+            load_from_cache_file=False,
+        )
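
A small sketch of the new aggregation steps applied directly to a toy Hugging Face dataset. The `inputs` keyword is an assumption carried over from how steps are configured in the YAML files above; FilterDataset works the same way but looks its function up by name in the filters registry shown in the next file.

from datasets import Dataset

from ragfit.processing.global_steps.aggregation import DatasetTagger, SelectColumns

# toy dataset standing in for a loaded split
datasets = {"train": Dataset.from_dict({"query": ["who wrote hamlet"], "extra": [0]})}

# keep only the query column, then tag every example with its dataset name
SelectColumns(columns=["query"], inputs="train").process("train", datasets)  # `inputs` kwarg assumed from the YAML configs
DatasetTagger(keyword="source", inputs="train").process("train", datasets)

print(datasets["train"][0])  # {'query': 'who wrote hamlet', 'source': 'train'}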

ragfit/processing/global_steps/filters.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+"""Module containing filters"""
+
+
+def msmarco_positive_filter(x):
+    return 1 in x["passages"]["is_selected"]
+
+
+filters = {"MSMARCO": msmarco_positive_filter}
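
Filters are referenced by name, so extending the registry is a one-liner. A hypothetical example, not part of the commit:

from ragfit.processing.global_steps.filters import filters

# hypothetical filter: keep only examples that actually have a query
def has_query(example):
    return bool(example.get("query"))

filters["has_query"] = has_query
# a FilterDataset step could then be configured with filter_fn: has_query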

ragfit/processing/global_steps/sampling.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def process_all(self, dataset, datasets, **kwargs):
         if self.shuffle:
             dataset = dataset.shuffle(seed=self.shuffle)
         if self.limit:
-            dataset = dataset.select(range(self.limit))
+            dataset = dataset.select(range(min(len(dataset), self.limit)))
         return dataset
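
The ShuffleSelect fix clamps the limit to the dataset size, so a limit larger than the split no longer fails. A quick illustration with toy data:

from datasets import Dataset

ds = Dataset.from_dict({"query": [f"q{i}" for i in range(100)]})
limit = 10_000
# keeps all 100 rows; an unclamped range(limit) would pass out-of-range indices to select()
ds = ds.select(range(min(len(ds), limit)))
print(len(ds))  # 100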

ragfit/processing/local_steps/formatting.py

Lines changed: 20 additions & 0 deletions
@@ -40,3 +40,23 @@ def __init__(self, input_key, output_key, string_join=", ", **kwargs):
     def process_item(self, item, index, datasets, **kwargs):
         item[self.output_key] = self.string_join.join(item[self.input_key])
         return item
+
+
+class UpdateField(LocalStep):
+    """
+    Class to update a field in the dataset with a new value.
+    """
+
+    def __init__(self, input_key: str, value, **kwargs):
+        """
+        Args:
+            input_key (str): example key to change.
+            value: New value to set for the field.
+        """
+        super().__init__(**kwargs)
+        self.input_key = input_key
+        self.value = value
+
+    def process_item(self, item, index, datasets, **kwargs):
+        item[self.input_key] = self.value
+        return item
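
A brief sketch of UpdateField overwriting a field with a fixed value, for example blanking a column to build a no-context variant of a dataset. The field name is hypothetical and the `inputs` keyword is assumed as in the YAML configs:

from ragfit.processing.local_steps.formatting import UpdateField

step = UpdateField(input_key="context", value="", inputs="train")  # `inputs` assumed from the YAML step configs
item = step.process_item({"query": "who wrote hamlet", "context": "retrieved passages..."}, 0, {})
print(item)  # {'query': 'who wrote hamlet', 'context': ''}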

ragfit/processing/local_steps/inference.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+"""Module for inference steps, which can use LLM output to augment the data."""
+
+from ragfit.models.vllm import VLLMInference
+
+from ..step import LocalStep
+
+
+class HFStep(LocalStep):
+    """
+    Class for running inference with a Hugging Face model based on the vLLM engine.
+    """
+
+    def __init__(self, input_key, output_key, model_kwargs, **kwargs):
+        """
+        Initialize the HFStep class.
+
+        Args:
+            input_key (str): The key for the input text to be served as the prompt.
+            output_key (str): The key for saving the generated text.
+            model_kwargs (dict): The keyword arguments to pass to the vLLM model.
+            **kwargs: Additional keyword arguments to pass to the LocalStep.
+        """
+        super().__init__(**kwargs)
+        self.input_key = input_key
+        self.output_key = output_key
+        self.model_kwargs = model_kwargs
+        self.model = VLLMInference(**model_kwargs)
+
+    def process_item(self, item, index, datasets, **kwargs):
+        prompt = item[self.input_key]
+        response = self.model.generate(prompt)
+        item[self.output_key] = response
+        return item
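
Constructing the step directly mirrors the HFStep entry in configs/processing-nq.yaml above. This is a sketch only: it loads the model through vLLM, so it needs the GPUs and weights named in the config, and the `inputs` keyword is assumed as in the YAML step configs.

from ragfit.processing.local_steps.inference import HFStep

step = HFStep(
    input_key="prompt",
    output_key="generated",
    model_kwargs=dict(
        model_name_or_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        instruction="ragfit/processing/prompts/prompt_instructions/qa-short.txt",
        num_gpus=2,
        llm_params={"dtype": "auto", "max_model_len": 4096},
        generation={"temperature": 0, "max_tokens": 50},
    ),
    inputs="train",  # assumed from the YAML step configs
)
item = step.process_item({"prompt": "Question: who wrote Hamlet?\nAnswer:"}, 0, {})
print(item["generated"])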

ragfit/processing/pipeline.py

Lines changed: 3 additions & 2 deletions
@@ -57,8 +57,9 @@ def gen_cache_fn(self, step, index, dataset_name):
         Returns a string.
         """
         return (
-            f"{self.output_path}/{self.name}"
-            f"_{index}_{type(step).__name__}"
+            f"{self.output_path}/cache"
+            f"_{self.name}_{index}"
+            f"_{type(step).__name__}"
             f"_{dataset_name}_{step.get_hash()}.json"
         )
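
With the new scheme, cache files share a common cache_ prefix and keep the pipeline name and step index adjacent. An illustrative filename, with made-up values:

# illustrative values only
output_path, name, index, step_name, dataset_name, step_hash = ".", "nq", 3, "HFStep", "train", "a1b2c3"
print(f"{output_path}/cache_{name}_{index}_{step_name}_{dataset_name}_{step_hash}.json")
# ./cache_nq_3_HFStep_train_a1b2c3.json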

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-You are a helpful question answerer who can provide an answer given a question and relevant context. Please answer shortly as possible and don't repeat the question.
+You are a helpful question answerer who can provide an answer given a question and relevant context. Answer the following question with a short span. The answer needs to be just in a few words.

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+You are a helpful question answerer who can provide an answer given a question and relevant context. Please answer with "yes", "no" or "maybe", if there is not enough information to answer the question.

ruff.toml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+line-length = 90
+
+[lint]
+select = ["E", "F", "W", "I", "N", "Q"]
+ignore = ["E203", "F841", "E501", "F821"]
+exclude = ["*.ipynb"]
