Skip to content

Commit 6534b8a

Browse files
committed
✨ Add hf_load_kwargs to DatasetConfig
1 parent edb20e6 commit 6534b8a

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

hezar/configs.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,14 +322,37 @@ class PreprocessorConfig(Config):
322322
class DatasetConfig(Config):
323323
"""
324324
Base dataclass for all dataset configs
325+
326+
Args:
327+
path (str):
328+
Path to the dataset either on the Hub or local. Supported syntax is either `<path>` or `<path>:<name>` where
329+
<name> is the parameter `name` in the `load_dataset()`
330+
task (str):
331+
A supported task for the dataset
332+
hf_load_kwargs (dict):
333+
Keyword arguments to pass to the HF `datasets.load_dataset()`; the keys `path`, `cache_dir` and `split` are dropped
325334
"""
326335

327336
name: str = field(init=False, default=None)
328337
config_type: str = field(init=False, default=ConfigType.DATASET)
338+
path: str = None
329339
task: TaskType | List[TaskType] = field(
330-
default=None, metadata={"help": "Name of the task(s) this dataset is built for"}
340+
default=None,
341+
metadata={"help": "Name of the task(s) this dataset is built for"}
331342
)
332-
path: str = None
343+
hf_load_kwargs: dict = None
344+
345+
def __post_init__(self):
    """
    Normalize `path` and `hf_load_kwargs` once the dataclass fields are set.

    - A `path` written as `<path>:<name>` is split on its last colon: the left part is stored back in
      `self.path` and the right part is stored as `hf_load_kwargs["name"]`.
    - The keys `path`, `cache_dir` and `split` are removed from `hf_load_kwargs`.
    - `hf_load_kwargs` is always a dict afterwards (empty when nothing was given).
    """
    super().__post_init__()

    # Build the kwargs dict up front: the field defaults to None, so writing the `name` key below would
    # otherwise raise a TypeError. Copying also keeps the caller's dict untouched by the pops further down.
    load_kwargs = dict(self.hf_load_kwargs) if self.hf_load_kwargs else {}

    if self.path and ":" in self.path:
        # `rpartition` splits on the last colon only, so a path holding several colons no longer fails
        # with an unpacking ValueError; for a single colon the result is the same as `split(":")`.
        self.path, _, config_name = self.path.rpartition(":")
        load_kwargs["name"] = config_name

    # NOTE(review): presumably these are supplied separately by whoever calls `load_dataset()` - confirm.
    for reserved_key in ("path", "cache_dir", "split"):
        load_kwargs.pop(reserved_key, None)

    self.hf_load_kwargs = load_kwargs
333356

334357

335358
@dataclass

0 commit comments

Comments
 (0)