
Commit

✨ Make _load() method abstract for all dataset classes
arxyzan committed Jun 11, 2024
1 parent 6534b8a commit 18ff1b4
Showing 7 changed files with 54 additions and 32 deletions.
39 changes: 35 additions & 4 deletions hezar/data/datasets/dataset.py
@@ -15,7 +15,7 @@
     SplitType,
 )
 from ...preprocessors import Preprocessor, PreprocessorsContainer
-from ...utils import verify_dependencies
+from ...utils import get_module_config_class, list_repo_files, verify_dependencies


 class Dataset(TorchDataset):
@@ -41,15 +41,23 @@ class Dataset(TorchDataset):
     def __init__(
         self,
         config: DatasetConfig,
-        split=None,
+        split: str = "train",
         preprocessor: str | Preprocessor | PreprocessorsContainer = None,
         **kwargs,
     ):
         verify_dependencies(self, self.required_backends)
         self.config = config.update(kwargs)
-        self.data_collator = None
         self.split = split
+        self.data = self._load(self.split)
         self.preprocessor = self.create_preprocessor(preprocessor)
+        self.data_collator = None

+    def _load(self, split):
+        """
+        The internal function to load the dataset files and properties.
+        By default, this uses the HF `datasets.load_dataset()`.
+        """
+        pass

     @staticmethod
     def create_preprocessor(preprocessor: str | Preprocessor | PreprocessorsContainer):
@@ -139,17 +147,40 @@ def load(
         """
         split = split or "train"
         config_filename = config_filename or cls.config_filename

+        if ":" in hub_path:
+            hub_path, hf_dataset_config_name = hub_path.split(":")
+            kwargs["hf_load_kwargs"] = kwargs.get("hf_load_kwargs", {})
+            kwargs["hf_load_kwargs"]["name"] = hf_dataset_config_name

         if cache_dir is not None:
             cls.cache_dir = cache_dir

+        has_config = config_filename in list_repo_files(hub_path, repo_type="dataset")

         if config is not None:
             dataset_config = config.update(kwargs)
-        else:
+        elif has_config:
             dataset_config = DatasetConfig.load(
                 hub_path,
                 filename=config_filename,
                 repo_type=RepoType.DATASET,
                 cache_dir=cls.cache_dir,
                 **kwargs,
             )
+        elif kwargs.get("task", None):
+            config_cls = get_module_config_class(kwargs["task"], registry_type="dataset")
+            if config_cls:
+                dataset_config = config_cls(**kwargs)
+            else:
+                raise ValueError(f"Task `{kwargs['task']}` is not valid!")
+        else:
+            raise ValueError(
+                f"The dataset at `{hub_path}` does not have enough info and config to load using Hezar!"
+                f"\nHint: Either pass the proper `config` to `.load()` or pass in required config parameters as "
+                f"kwargs in `.load()`, most notably `task`!"
+            )

         dataset_config.path = hub_path
         dataset = build_dataset(
             dataset_config.name,
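With this change, `_load()` becomes the single override point for dataset subclasses: the base `__init__` now calls `self._load(self.split)` and assigns the result to `self.data`, so subclasses no longer load data in their own `__init__`. A minimal sketch of a custom subclass under the new contract (the class name, CSV path, and import paths are assumed for illustration; this is not code from the commit):

# Hypothetical subclass relying on the abstract `_load()` contract shown above.
from datasets import load_dataset

from hezar.data import Dataset  # import path assumed


class CSVTextDataset(Dataset):
    """Toy dataset that reads a local CSV file; names and paths here are illustrative."""

    def _load(self, split):
        # Called by the base __init__; whatever is returned here ends up in self.data.
        return load_dataset("csv", data_files={"train": "my_data.csv"}, split=split)

Subclasses still provide their own `__getitem__`, data collators, and preprocessor wiring as before; only the loading step is standardized.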
5 changes: 2 additions & 3 deletions hezar/data/datasets/image_captioning_dataset.py
@@ -42,7 +42,6 @@ class ImageCaptioningDataset(Dataset):

     def __init__(self, config: ImageCaptioningDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config=config, split=split, preprocessor=preprocessor, **kwargs)
-        self.data = self._load(split)
         self.image_processor = self.preprocessor.image_processor
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = ImageCaptioningDataCollator(
@@ -61,7 +60,7 @@ def __len__(self):
         """
         return len(self.data)

-    def _load(self, split=None):
+    def _load(self, split):
         """
         Load the dataset and clean up invalid samples.
@@ -72,7 +71,7 @@ def _load(self, split=None):
             Dataset: The cleaned dataset.
         """
-        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
         return data

     def __getitem__(self, index):
3 changes: 1 addition & 2 deletions hezar/data/datasets/ocr_dataset.py
@@ -80,7 +80,6 @@ class OCRDataset(Dataset):

     def __init__(self, config: OCRDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config=config, split=split, preprocessor=preprocessor, **kwargs)
-        self.data = self._load(split)
         self.image_processor = self.preprocessor.image_processor
         if self.config.text_split_type == TextSplitType.TOKENIZE:
             if self.config.tokenizer_path is not None:
@@ -113,7 +112,7 @@ def _load(self, split=None):
             Dataset: The cleaned dataset.
         """
-        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
         # Cleanup dataset
         valid_indices = []
         invalid_indices = []
12 changes: 5 additions & 7 deletions hezar/data/datasets/sequence_labeling_dataset.py
@@ -53,7 +53,6 @@ class SequenceLabelingDataset(Dataset):

     def __init__(self, config: SequenceLabelingDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config, split=split, preprocessor=preprocessor, **kwargs)
-        self.dataset = self._load(split)
         self._extract_labels()
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = SequenceLabelingDataCollator(self.tokenizer, max_length=self.config.max_length)
@@ -69,15 +68,14 @@ def _load(self, split):
             The whole dataset.
         """
         # TODO: In case we want to make this class work on other types like csv, json, etc. we have to do it here.
-        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
-        return dataset
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
+        return data

     def _extract_labels(self):
         """
         Extract label names, ids and build dictionaries.
         """
-        tags_list = self.dataset.features[self.config.tags_field].feature.names
+        tags_list = self.data.features[self.config.tags_field].feature.names
         self.id2label = self.config.id2label = {k: str(v) for k, v in dict(enumerate(tags_list)).items()}
         self.label2id = self.config.label2id = {v: k for k, v in self.id2label.items()}
         self.num_labels = self.config.num_labels = len(tags_list)
@@ -90,7 +88,7 @@ def __len__(self):
             int: The length of the dataset.
         """
-        return len(self.dataset)
+        return len(self.data)

     def _tokenize_and_align(self, tokens, labels):
         """
@@ -143,6 +141,6 @@ def __getitem__(self, index):
             dict: The input data.
         """
-        tokens, tags = self.dataset[index].values()
+        tokens, tags = self.data[index].values()
         inputs = self._tokenize_and_align(tokens, tags)
         return inputs
3 changes: 1 addition & 2 deletions hezar/data/datasets/speech_recognition_dataset.py
@@ -36,7 +36,6 @@ class SpeechRecognitionDataset(Dataset):

     def __init__(self, config: SpeechRecognitionDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config, split, preprocessor=preprocessor, **kwargs)
-        self.data = self._load(split)
         self.feature_extractor = self.preprocessor.audio_feature_extractor
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = SpeechRecognitionDataCollator(
@@ -49,7 +48,7 @@ def __init__(self, config: SpeechRecognitionDatasetConfig, split=None, preproces
         )

     def _load(self, split):
-        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
         data = data.cast_column(self.config.audio_column, Audio(sampling_rate=self.config.sampling_rate))
         return data
12 changes: 5 additions & 7 deletions hezar/data/datasets/text_classification_dataset.py
@@ -49,7 +49,6 @@ class TextClassificationDataset(Dataset):

     def __init__(self, config: TextClassificationDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config, split=split, preprocessor=preprocessor, **kwargs)
-        self.dataset = self._load(split)
         self._extract_labels()
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = TextPaddingDataCollator(
@@ -68,15 +67,14 @@ def _load(self, split):
             The whole dataset.
         """
         # TODO: In case we want to make this class work on other types like csv, json, etc. we have to do it here.
-        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
+        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
         return dataset

     def _extract_labels(self):
         """
         Extract label names, ids and build dictionaries.
         """
-        labels_list = self.dataset.features[self.config.label_field].names
+        labels_list = self.data.features[self.config.label_field].names
         self.id2label = self.config.id2label = {k: str(v) for k, v in dict(enumerate(labels_list)).items()}
         self.label2id = self.config.label2id = {v: k for k, v in self.id2label.items()}
         self.num_labels = self.config.num_labels = len(labels_list)
@@ -89,7 +87,7 @@ def __len__(self):
             int: The length of the dataset.
         """
-        return len(self.dataset)
+        return len(self.data)

     def __getitem__(self, index):
         """
@@ -102,8 +100,8 @@ def __getitem__(self, index):
             dict: The input data.
         """
-        text = self.dataset[index][self.config.text_field]
-        label = self.dataset[index][self.config.label_field]
+        text = self.data[index][self.config.text_field]
+        label = self.data[index][self.config.label_field]
         inputs = self.tokenizer(
             text,
             return_tensors="pt",
12 changes: 5 additions & 7 deletions hezar/data/datasets/text_summarization_dataset.py
@@ -53,7 +53,6 @@ class TextSummarizationDataset(Dataset):

     def __init__(self, config: TextSummarizationDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config, split=split, preprocessor=preprocessor, **kwargs)
-        self.dataset = self._load(split)
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = TextGenerationDataCollator(
             tokenizer=self.tokenizer,
@@ -73,9 +72,8 @@ def _load(self, split):
             The whole dataset.
         """
         # TODO: In case we want to make this class work on other types like csv, json, etc. we have to do it here.
-        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
-        return dataset
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
+        return data

     def __len__(self):
         """
@@ -85,7 +83,7 @@ def __len__(self):
             int: The length of the dataset.
         """
-        return len(self.dataset)
+        return len(self.data)

     def __getitem__(self, index):
         """
@@ -98,10 +96,10 @@ def __getitem__(self, index):
             dict: The input data.
         """
-        text = self.dataset[index][self.config.text_field]
+        text = self.data[index][self.config.text_field]
         if self.config.prefix is not None:
             text = self.config.prefix + text  # for conditional generation we might need a static prefix
-        summary = self.dataset[index][self.config.summary_field]
+        summary = self.data[index][self.config.summary_field]

         inputs = self.tokenizer(
             text,
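The `load()` changes in dataset.py above enable two call patterns for all of these dataset classes: a `path:config_name` suffix that is forwarded to HF `datasets` through `hf_load_kwargs["name"]`, and a `task` kwarg that builds the matching dataset config when the repo carries no Hezar config file. A usage sketch (the repo ids and kwargs below are hypothetical, and depending on the dataset additional arguments such as a preprocessor may still be required):

from hezar.data import Dataset  # import path assumed

# 1) "path:config_name" — the part after ":" becomes hf_load_kwargs["name"] and reaches
#    datasets.load_dataset(..., name="fa") inside the subclass _load().
dataset = Dataset.load("some-user/some-dataset:fa", split="train")

# 2) A plain HF dataset repo with no Hezar dataset config file — pass `task` (plus any
#    config fields that task needs) so get_module_config_class() can build the config.
dataset = Dataset.load(
    "some-user/plain-hf-dataset",
    split="train",
    task="text_classification",  # task name assumed to exist in Hezar's dataset registry
    text_field="text",
    label_field="label",
)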
