diff --git a/hezar/data/datasets/dataset.py b/hezar/data/datasets/dataset.py
index 9d47d56a..8def5798 100644
--- a/hezar/data/datasets/dataset.py
+++ b/hezar/data/datasets/dataset.py
@@ -15,7 +15,7 @@
     SplitType,
 )
 from ...preprocessors import Preprocessor, PreprocessorsContainer
-from ...utils import verify_dependencies
+from ...utils import get_module_config_class, list_repo_files, verify_dependencies
 
 
 class Dataset(TorchDataset):
@@ -41,15 +41,23 @@ class Dataset(TorchDataset):
     def __init__(
         self,
         config: DatasetConfig,
-        split=None,
+        split: str = "train",
         preprocessor: str | Preprocessor | PreprocessorsContainer = None,
         **kwargs,
     ):
         verify_dependencies(self, self.required_backends)
         self.config = config.update(kwargs)
-        self.data_collator = None
         self.split = split
+        self.data = self._load(self.split)
         self.preprocessor = self.create_preprocessor(preprocessor)
+        self.data_collator = None
+
+    def _load(self, split):
+        """
+        The internal function to load the dataset files and properties.
+        By default, this uses the HF `datasets.load_dataset()`.
+        """
+        pass
 
     @staticmethod
     def create_preprocessor(preprocessor: str | Preprocessor | PreprocessorsContainer):
@@ -139,17 +147,40 @@ def load(
         """
         split = split or "train"
         config_filename = config_filename or cls.config_filename
+
+        if ":" in hub_path:
+            hub_path, hf_dataset_config_name = hub_path.split(":")
+            kwargs["hf_load_kwargs"] = kwargs.get("hf_load_kwargs", {})
+            kwargs["hf_load_kwargs"]["name"] = hf_dataset_config_name
+
         if cache_dir is not None:
             cls.cache_dir = cache_dir
+
+        has_config = config_filename in list_repo_files(hub_path, repo_type="dataset")
+
         if config is not None:
             dataset_config = config.update(kwargs)
-        else:
+        elif has_config:
             dataset_config = DatasetConfig.load(
                 hub_path,
                 filename=config_filename,
                 repo_type=RepoType.DATASET,
                 cache_dir=cls.cache_dir,
+                **kwargs,
             )
+        elif kwargs.get("task", None):
+            config_cls = get_module_config_class(kwargs["task"], registry_type="dataset")
+            if config_cls:
+                dataset_config = config_cls(**kwargs)
+            else:
+                raise ValueError(f"Task `{kwargs['task']}` is not valid!")
+        else:
+            raise ValueError(
+                f"The dataset at `{hub_path}` does not have enough info and config to load using Hezar!"
+                f"\nHint: Either pass the proper `config` to `.load()` or pass in required config parameters as "
+                f"kwargs in `.load()`, most notably `task`!"
+            )
+
         dataset_config.path = hub_path
         dataset = build_dataset(
             dataset_config.name,
diff --git a/hezar/data/datasets/image_captioning_dataset.py b/hezar/data/datasets/image_captioning_dataset.py
index 209cb015..417bb51e 100644
--- a/hezar/data/datasets/image_captioning_dataset.py
+++ b/hezar/data/datasets/image_captioning_dataset.py
@@ -42,7 +42,6 @@ class ImageCaptioningDataset(Dataset):
 
     def __init__(self, config: ImageCaptioningDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config=config, split=split, preprocessor=preprocessor, **kwargs)
-        self.data = self._load(split)
         self.image_processor = self.preprocessor.image_processor
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = ImageCaptioningDataCollator(
@@ -61,7 +60,7 @@ def __len__(self):
         """
         return len(self.data)
 
-    def _load(self, split=None):
+    def _load(self, split):
         """
         Load the dataset and clean up invalid samples.
@@ -72,7 +71,7 @@
             Dataset: The cleaned dataset.
 
         """
-        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
         return data
 
     def __getitem__(self, index):
diff --git a/hezar/data/datasets/ocr_dataset.py b/hezar/data/datasets/ocr_dataset.py
index bbbdcf06..36e3d3bc 100644
--- a/hezar/data/datasets/ocr_dataset.py
+++ b/hezar/data/datasets/ocr_dataset.py
@@ -80,7 +80,6 @@ class OCRDataset(Dataset):
 
     def __init__(self, config: OCRDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config=config, split=split, preprocessor=preprocessor, **kwargs)
-        self.data = self._load(split)
         self.image_processor = self.preprocessor.image_processor
         if self.config.text_split_type == TextSplitType.TOKENIZE:
             if self.config.tokenizer_path is not None:
@@ -113,7 +112,7 @@ def _load(self, split=None):
             Dataset: The cleaned dataset.
 
         """
-        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
         # Cleanup dataset
         valid_indices = []
         invalid_indices = []
diff --git a/hezar/data/datasets/sequence_labeling_dataset.py b/hezar/data/datasets/sequence_labeling_dataset.py
index 1690efba..7cb0bc52 100644
--- a/hezar/data/datasets/sequence_labeling_dataset.py
+++ b/hezar/data/datasets/sequence_labeling_dataset.py
@@ -53,7 +53,6 @@ class SequenceLabelingDataset(Dataset):
 
     def __init__(self, config: SequenceLabelingDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config, split=split, preprocessor=preprocessor, **kwargs)
-        self.dataset = self._load(split)
         self._extract_labels()
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = SequenceLabelingDataCollator(self.tokenizer, max_length=self.config.max_length)
@@ -69,15 +68,14 @@ def _load(self, split):
             The whole dataset.
 
         """
-        # TODO: In case we want to make this class work on other types like csv, json, etc. we have to do it here.
-        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
-        return dataset
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
+        return data
 
     def _extract_labels(self):
         """
         Extract label names, ids and build dictionaries.
         """
-        tags_list = self.dataset.features[self.config.tags_field].feature.names
+        tags_list = self.data.features[self.config.tags_field].feature.names
         self.id2label = self.config.id2label = {k: str(v) for k, v in dict(enumerate(tags_list)).items()}
         self.label2id = self.config.label2id = {v: k for k, v in self.id2label.items()}
         self.num_labels = self.config.num_labels = len(tags_list)
@@ -90,7 +88,7 @@ def __len__(self):
             int: The length of the dataset.
 
         """
-        return len(self.dataset)
+        return len(self.data)
 
     def _tokenize_and_align(self, tokens, labels):
         """
@@ -143,6 +141,6 @@ def __getitem__(self, index):
             dict: The input data.
 
         """
-        tokens, tags = self.dataset[index].values()
+        tokens, tags = self.data[index].values()
         inputs = self._tokenize_and_align(tokens, tags)
         return inputs
diff --git a/hezar/data/datasets/speech_recognition_dataset.py b/hezar/data/datasets/speech_recognition_dataset.py
index 4334fd28..6af894a7 100644
--- a/hezar/data/datasets/speech_recognition_dataset.py
+++ b/hezar/data/datasets/speech_recognition_dataset.py
@@ -36,7 +36,6 @@ class SpeechRecognitionDataset(Dataset):
 
     def __init__(self, config: SpeechRecognitionDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config, split, preprocessor=preprocessor, **kwargs)
-        self.data = self._load(split)
         self.feature_extractor = self.preprocessor.audio_feature_extractor
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = SpeechRecognitionDataCollator(
@@ -49,7 +48,7 @@ def __init__(self, config: SpeechRecognitionDatasetConfig, split=None, preproces
         )
 
     def _load(self, split):
-        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
         data = data.cast_column(self.config.audio_column, Audio(sampling_rate=self.config.sampling_rate))
         return data
 
diff --git a/hezar/data/datasets/text_classification_dataset.py b/hezar/data/datasets/text_classification_dataset.py
index 8419efdd..bfd4f32d 100644
--- a/hezar/data/datasets/text_classification_dataset.py
+++ b/hezar/data/datasets/text_classification_dataset.py
@@ -49,7 +49,6 @@ class TextClassificationDataset(Dataset):
 
     def __init__(self, config: TextClassificationDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config, split=split, preprocessor=preprocessor, **kwargs)
-        self.dataset = self._load(split)
         self._extract_labels()
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = TextPaddingDataCollator(
@@ -68,15 +67,14 @@ def _load(self, split):
             The whole dataset.
 
         """
-        # TODO: In case we want to make this class work on other types like csv, json, etc. we have to do it here.
-        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
+        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
         return dataset
 
     def _extract_labels(self):
         """
         Extract label names, ids and build dictionaries.
         """
-        labels_list = self.dataset.features[self.config.label_field].names
+        labels_list = self.data.features[self.config.label_field].names
         self.id2label = self.config.id2label = {k: str(v) for k, v in dict(enumerate(labels_list)).items()}
         self.label2id = self.config.label2id = {v: k for k, v in self.id2label.items()}
         self.num_labels = self.config.num_labels = len(labels_list)
@@ -89,7 +87,7 @@ def __len__(self):
             int: The length of the dataset.
 
         """
-        return len(self.dataset)
+        return len(self.data)
 
     def __getitem__(self, index):
         """
@@ -102,8 +100,8 @@ def __getitem__(self, index):
             dict: The input data.
 
         """
-        text = self.dataset[index][self.config.text_field]
-        label = self.dataset[index][self.config.label_field]
+        text = self.data[index][self.config.text_field]
+        label = self.data[index][self.config.label_field]
         inputs = self.tokenizer(
             text,
             return_tensors="pt",
diff --git a/hezar/data/datasets/text_summarization_dataset.py b/hezar/data/datasets/text_summarization_dataset.py
index 37baa6e7..a1d0a1c8 100644
--- a/hezar/data/datasets/text_summarization_dataset.py
+++ b/hezar/data/datasets/text_summarization_dataset.py
@@ -53,7 +53,6 @@ class TextSummarizationDataset(Dataset):
 
     def __init__(self, config: TextSummarizationDatasetConfig, split=None, preprocessor=None, **kwargs):
         super().__init__(config, split=split, preprocessor=preprocessor, **kwargs)
-        self.dataset = self._load(split)
         self.tokenizer = self.preprocessor.tokenizer
         self.data_collator = TextGenerationDataCollator(
             tokenizer=self.tokenizer,
@@ -73,9 +72,8 @@ def _load(self, split):
             The whole dataset.
 
         """
-        # TODO: In case we want to make this class work on other types like csv, json, etc. we have to do it here.
-        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
-        return dataset
+        data = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir, **self.config.hf_load_kwargs)
+        return data
 
     def __len__(self):
         """
@@ -85,7 +83,7 @@ def __len__(self):
             int: The length of the dataset.
 
         """
-        return len(self.dataset)
+        return len(self.data)
 
     def __getitem__(self, index):
         """
@@ -98,10 +96,10 @@ def __getitem__(self, index):
             dict: The input data.
 
         """
-        text = self.dataset[index][self.config.text_field]
+        text = self.data[index][self.config.text_field]
         if self.config.prefix is not None:
             text = self.config.prefix + text  # for conditional generation we might need a static prefix
-        summary = self.dataset[index][self.config.summary_field]
+        summary = self.data[index][self.config.summary_field]
         inputs = self.tokenizer(
             text,