From 1b9969db0f59695f6b3638300b396e1ef72b9d54 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 18 Oct 2024 18:39:01 +0330 Subject: [PATCH 01/33] :sparkles: Add `dataset_processors` and `DatasetProcessor` base class --- hezar/data/__init__.py | 1 + .../dataset_processors/dataset_processor.py | 54 +++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 hezar/data/dataset_processors/dataset_processor.py diff --git a/hezar/data/__init__.py b/hezar/data/__init__.py index af95c156..93d4354a 100644 --- a/hezar/data/__init__.py +++ b/hezar/data/__init__.py @@ -1,5 +1,6 @@ from ..registry import datasets_registry # noqa from ..builders import build_dataset # noqa +from .dataset_processors import * from .data_collators import * from .data_samplers import * from .datasets import * diff --git a/hezar/data/dataset_processors/dataset_processor.py b/hezar/data/dataset_processors/dataset_processor.py new file mode 100644 index 00000000..57d6f924 --- /dev/null +++ b/hezar/data/dataset_processors/dataset_processor.py @@ -0,0 +1,54 @@ +""" +Dataset processors are a bunch of callable classes to be passed to be used as map functions for any dataset on the Hub. +Note that the main dataset classes are already implemented in a way that the processing is done in the `__getitem__` +method and these classes are only used for when the dataset has been loaded using HuggingFace datasets library. + +Example: +>>> from datasets import load_dataset +>>> from hezar.data import SpeechRecognitionDatasetProcessor + +>>> data_processor = SpeechRecognitionDatasetProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) +>>> dataset = load_dataset("hezarai/common-voice-13-fa") +>>> dataset = dataset.map(data_processor, batched=True, batch_size=1000) + +""" + + +class DatasetProcessor: + def __init__(self, batched=False, *args, **kwargs): + """ + Base constructor that accepts a `batched` flag and any other arguments for child class initialization. + + Args: + batched (bool): Whether to process data in batches or not. + """ + self.batched = batched + self.args = args + self.kwargs = kwargs + + def __call__(self, examples, **fn_kwargs): + """ + Method called when using the map function. + Decides whether to call `process()` or `batch_process()` based on the `batched` flag. + + Args: + examples (dict or list of dict): Data to process. + **fn_kwargs: Additional keyword arguments passed through the `map` function as `fn_kwargs`. + For example, `fn_kwargs` can contain custom settings like `sampling_rate`. + """ + if self.batched: + return self.batch_process(examples, **fn_kwargs) + else: + return self.process(examples, **fn_kwargs) + + def process(self, example, **kwargs): + """ + Method to process a single example + """ + raise NotImplementedError + + def batch_process(self, examples, **kwargs): + """ + Method to process a batch of examples. 
+ """ + raise NotImplementedError From 7980c1e5e89cc09c672fd5859ddd991ea13c679a Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 18 Oct 2024 18:39:26 +0330 Subject: [PATCH 02/33] :sparkles: Add `TextClassificationDatasetProcessor` to dataset processors --- hezar/data/dataset_processors/__init__.py | 2 + .../image_captioning_processor.py | 71 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 hezar/data/dataset_processors/__init__.py create mode 100644 hezar/data/dataset_processors/image_captioning_processor.py diff --git a/hezar/data/dataset_processors/__init__.py b/hezar/data/dataset_processors/__init__.py new file mode 100644 index 00000000..f40f7425 --- /dev/null +++ b/hezar/data/dataset_processors/__init__.py @@ -0,0 +1,2 @@ +from .dataset_processor import DatasetProcessor +from .text_classification_processor import TextClassificationDatasetProcessor diff --git a/hezar/data/dataset_processors/image_captioning_processor.py b/hezar/data/dataset_processors/image_captioning_processor.py new file mode 100644 index 00000000..a8f3f9fa --- /dev/null +++ b/hezar/data/dataset_processors/image_captioning_processor.py @@ -0,0 +1,71 @@ +from ...utils import shift_tokens_right +from .dataset_processor import DatasetProcessor + + +class ImageCaptioningDatasetProcessor(DatasetProcessor): + def __init__(self, image_processor, tokenizer, batched=False, max_length=None, padding=None): + super().__init__(batched=batched) + self.image_processor = image_processor + self.tokenizer = tokenizer + self.max_length = max_length + self.padding = padding + + def process(self, example, padding=None, max_length=None): + """ + Process image and tokenize captions for a single data sample. + + Args: + example: A data example containing the image and its caption + padding: Padding type e.g, max_length, longest. + max_length: Max length value if padding is set to max_length or the labels must be truncated. + + Returns: + A dict of pixel values tensor of the processed image and labels token ids and attention mask. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + path, text = example["path"], example["text"] + tokenized_inputs = self.tokenizer(text, padding=padding, max_length=max_length, return_tensors="pt") + + example["pixel_values"] = self.image_processor(path, return_tensors="pt")["pixel_values"] + example["labels"] = tokenized_inputs["input_ids"] + example["attention_mask"] = tokenized_inputs["attention_mask"] + example["decoder_input_ids"] = shift_tokens_right( + example["labels"], + pad_token_id=self.tokenizer.pad_token_id, + decoder_start_token_id=self.tokenizer.bos_token_id, + ) + + return example + + def batch_process(self, examples, padding=None, max_length=None): + """ + Process image and tokenize captions for a single data sample. + + Args: + examples: A batch of data examples containing the images and their captions + padding: Padding type e.g, max_length, longest. + max_length: Max length value if padding is set to max_length or the labels must be truncated. + + Returns: + A dict of pixel values tensor of the processed images and labels token ids and attention masks. 
+ """ + padding = padding or self.padding + max_length = max_length or self.max_length + + paths = [example["path"] for example in examples] + texts = [example["text"] for example in examples] + + tokenized_inputs = self.tokenizer(texts, padding=padding, max_length=max_length, return_tensors="pt") + + examples["pixel_values"] = self.image_processor(paths, return_tensors="pt")["pixel_values"] + examples["labels"] = tokenized_inputs["input_ids"] + examples["attention_mask"] = tokenized_inputs["attention_mask"] + examples["decoder_input_ids"] = shift_tokens_right( + examples["labels"], + pad_token_id=self.tokenizer.pad_token_id, + decoder_start_token_id=self.tokenizer.bos_token_id, + ) + + return examples From f29e9d2748b90f5681c3a0f4ab0f40d58668e0bb Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 18 Oct 2024 18:39:51 +0330 Subject: [PATCH 03/33] :sparkles: Add `TextClassificationDatasetProcessor` to dataset processors --- .../text_classification_processor.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 hezar/data/dataset_processors/text_classification_processor.py diff --git a/hezar/data/dataset_processors/text_classification_processor.py b/hezar/data/dataset_processors/text_classification_processor.py new file mode 100644 index 00000000..4ec90d8a --- /dev/null +++ b/hezar/data/dataset_processors/text_classification_processor.py @@ -0,0 +1,59 @@ +import torch + +from .dataset_processor import DatasetProcessor + + +class TextClassificationDatasetProcessor(DatasetProcessor): + def __init__(self, tokenizer, id2label, batched=False, max_length=None, padding=None): + super().__init__(batched=batched) + self.tokenizer = tokenizer + self.id2label = id2label + self.label2id = {v: k for k, v in self.id2label.items()} + self.padding = padding + self.max_length = max_length + + def process(self, example, padding=None, max_length=None): + """ + Process a single example. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + text = example["text"] + label = example["label"] + + inputs = self.tokenizer( + text, + return_tensors="torch", + truncation="longest_first", + padding=padding, + max_length=max_length, + return_attention_mask=True, + ) + example.update(inputs) + example["labels"] = torch.tensor([label], dtype=torch.long) + + return example + + def batch_process(self, examples, padding=None, max_length=None): + """ + Process a batch of examples. 
+ """ + padding = padding or self.padding + max_length = max_length or self.max_length + + texts = examples["text"] + labels = examples["label"] + + inputs = self.tokenizer( + texts, + return_tensors="torch", + truncation=True, + padding=padding, + max_length=max_length, + return_attention_mask=True, + ) + examples.update(inputs) + examples["labels"] = torch.tensor(labels, dtype=torch.long) + + return examples From 0bef124ee803ad5049a2186f40c52982377db5af Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 17:59:10 +0330 Subject: [PATCH 04/33] :sparkles: Improve dataset processor logic for handling batched mode --- hezar/data/dataset_processors/__init__.py | 1 + .../dataset_processors/dataset_processor.py | 36 +++++++------ .../image_captioning_processor.py | 50 ++++++++++--------- .../text_classification_processor.py | 30 ++++++----- 4 files changed, 62 insertions(+), 55 deletions(-) diff --git a/hezar/data/dataset_processors/__init__.py b/hezar/data/dataset_processors/__init__.py index f40f7425..5651e297 100644 --- a/hezar/data/dataset_processors/__init__.py +++ b/hezar/data/dataset_processors/__init__.py @@ -1,2 +1,3 @@ from .dataset_processor import DatasetProcessor +from .image_captioning_processor import ImageCaptioningDatasetProcessor from .text_classification_processor import TextClassificationDatasetProcessor diff --git a/hezar/data/dataset_processors/dataset_processor.py b/hezar/data/dataset_processors/dataset_processor.py index 57d6f924..37e94fe6 100644 --- a/hezar/data/dataset_processors/dataset_processor.py +++ b/hezar/data/dataset_processors/dataset_processor.py @@ -12,42 +12,48 @@ >>> dataset = dataset.map(data_processor, batched=True, batch_size=1000) """ +from ...constants import Backends +from ...utils import is_backend_available, verify_dependencies + + +if is_backend_available(Backends.DATASETS): + from datasets.formatting.formatting import LazyBatch, LazyRow class DatasetProcessor: - def __init__(self, batched=False, *args, **kwargs): + required_backends = [Backends.DATASETS] + + def __init__(self, *args, **kwargs): """ Base constructor that accepts a `batched` flag and any other arguments for child class initialization. - - Args: - batched (bool): Whether to process data in batches or not. """ - self.batched = batched + verify_dependencies(self, self.required_backends) self.args = args self.kwargs = kwargs - def __call__(self, examples, **fn_kwargs): + def __call__(self, data: LazyBatch | LazyRow, **kwargs): """ Method called when using the map function. - Decides whether to call `process()` or `batch_process()` based on the `batched` flag. + Decides whether to call `process_single()` or `process_batch()` based on the data values. Args: - examples (dict or list of dict): Data to process. - **fn_kwargs: Additional keyword arguments passed through the `map` function as `fn_kwargs`. - For example, `fn_kwargs` can contain custom settings like `sampling_rate`. + data: A dict of feature name -> sample or batch of samples mapping. + **kwargs: Additional keyword arguments passed through the `map` function as `kwargs`. 
""" - if self.batched: - return self.batch_process(examples, **fn_kwargs) + if isinstance(data, LazyRow): + return self.process_single(data, **kwargs) + elif isinstance(data, LazyBatch): + return self.process_batch(data, **kwargs) else: - return self.process(examples, **fn_kwargs) + raise ValueError(f"The input data must be either `LazyBatch` or `LazyRow`, got `{type(data)}`!") - def process(self, example, **kwargs): + def process_single(self, data: LazyRow, **kwargs): """ Method to process a single example """ raise NotImplementedError - def batch_process(self, examples, **kwargs): + def process_batch(self, data: LazyBatch, **kwargs): """ Method to process a batch of examples. """ diff --git a/hezar/data/dataset_processors/image_captioning_processor.py b/hezar/data/dataset_processors/image_captioning_processor.py index a8f3f9fa..1f6d2a85 100644 --- a/hezar/data/dataset_processors/image_captioning_processor.py +++ b/hezar/data/dataset_processors/image_captioning_processor.py @@ -3,19 +3,19 @@ class ImageCaptioningDatasetProcessor(DatasetProcessor): - def __init__(self, image_processor, tokenizer, batched=False, max_length=None, padding=None): - super().__init__(batched=batched) + def __init__(self, image_processor, tokenizer, max_length=None, padding=None): + super().__init__() self.image_processor = image_processor self.tokenizer = tokenizer self.max_length = max_length self.padding = padding - def process(self, example, padding=None, max_length=None): + def process_single(self, data, padding=None, max_length=None): """ Process image and tokenize captions for a single data sample. Args: - example: A data example containing the image and its caption + data: A data example containing the image and its caption padding: Padding type e.g, max_length, longest. max_length: Max length value if padding is set to max_length or the labels must be truncated. @@ -25,26 +25,28 @@ def process(self, example, padding=None, max_length=None): padding = padding or self.padding max_length = max_length or self.max_length - path, text = example["path"], example["text"] - tokenized_inputs = self.tokenizer(text, padding=padding, max_length=max_length, return_tensors="pt") + path = data["image_path"] + text = data["label"] - example["pixel_values"] = self.image_processor(path, return_tensors="pt")["pixel_values"] - example["labels"] = tokenized_inputs["input_ids"] - example["attention_mask"] = tokenized_inputs["attention_mask"] - example["decoder_input_ids"] = shift_tokens_right( - example["labels"], + tokenized_inputs = self.tokenizer(text, padding=padding, max_length=max_length, return_tensors="torch") + + data["pixel_values"] = self.image_processor(path, return_tensors="torch")["pixel_values"] + data["labels"] = tokenized_inputs["token_ids"] + data["attention_mask"] = tokenized_inputs["attention_mask"] + data["decoder_input_ids"] = shift_tokens_right( + data["labels"], pad_token_id=self.tokenizer.pad_token_id, decoder_start_token_id=self.tokenizer.bos_token_id, ) - return example + return data - def batch_process(self, examples, padding=None, max_length=None): + def process_batch(self, data, padding=None, max_length=None): """ - Process image and tokenize captions for a single data sample. + Process image and tokenize captions for a batch of data samples. Args: - examples: A batch of data examples containing the images and their captions + data: A batch of data examples containing the images and their captions padding: Padding type e.g, max_length, longest. 
max_length: Max length value if padding is set to max_length or the labels must be truncated. @@ -54,18 +56,18 @@ def batch_process(self, examples, padding=None, max_length=None): padding = padding or self.padding max_length = max_length or self.max_length - paths = [example["path"] for example in examples] - texts = [example["text"] for example in examples] + paths = data["image_path"] + texts = data["label"] - tokenized_inputs = self.tokenizer(texts, padding=padding, max_length=max_length, return_tensors="pt") + tokenized_inputs = self.tokenizer(texts, padding=padding, max_length=max_length, return_tensors="torch") - examples["pixel_values"] = self.image_processor(paths, return_tensors="pt")["pixel_values"] - examples["labels"] = tokenized_inputs["input_ids"] - examples["attention_mask"] = tokenized_inputs["attention_mask"] - examples["decoder_input_ids"] = shift_tokens_right( - examples["labels"], + data["pixel_values"] = self.image_processor(paths, return_tensors="torch")["pixel_values"] + data["labels"] = tokenized_inputs["token_ids"] + data["attention_mask"] = tokenized_inputs["attention_mask"] + data["decoder_input_ids"] = shift_tokens_right( + data["labels"], pad_token_id=self.tokenizer.pad_token_id, decoder_start_token_id=self.tokenizer.bos_token_id, ) - return examples + return data diff --git a/hezar/data/dataset_processors/text_classification_processor.py b/hezar/data/dataset_processors/text_classification_processor.py index 4ec90d8a..6e88b62a 100644 --- a/hezar/data/dataset_processors/text_classification_processor.py +++ b/hezar/data/dataset_processors/text_classification_processor.py @@ -4,23 +4,21 @@ class TextClassificationDatasetProcessor(DatasetProcessor): - def __init__(self, tokenizer, id2label, batched=False, max_length=None, padding=None): - super().__init__(batched=batched) + def __init__(self, tokenizer, max_length=None, padding=None): + super().__init__() self.tokenizer = tokenizer - self.id2label = id2label - self.label2id = {v: k for k, v in self.id2label.items()} self.padding = padding self.max_length = max_length - def process(self, example, padding=None, max_length=None): + def process_single(self, data, padding=None, max_length=None): """ Process a single example. """ padding = padding or self.padding max_length = max_length or self.max_length - text = example["text"] - label = example["label"] + text = data["text"] + label = data["label"] inputs = self.tokenizer( text, @@ -30,20 +28,20 @@ def process(self, example, padding=None, max_length=None): max_length=max_length, return_attention_mask=True, ) - example.update(inputs) - example["labels"] = torch.tensor([label], dtype=torch.long) + data.update(inputs) + data["labels"] = torch.tensor([label], dtype=torch.long) - return example + return data - def batch_process(self, examples, padding=None, max_length=None): + def process_batch(self, data, padding=None, max_length=None): """ Process a batch of examples. 
""" padding = padding or self.padding max_length = max_length or self.max_length - texts = examples["text"] - labels = examples["label"] + texts = data["text"] + labels = data["label"] inputs = self.tokenizer( texts, @@ -53,7 +51,7 @@ def batch_process(self, examples, padding=None, max_length=None): max_length=max_length, return_attention_mask=True, ) - examples.update(inputs) - examples["labels"] = torch.tensor(labels, dtype=torch.long) + data.update(inputs) + data["labels"] = torch.tensor(labels, dtype=torch.long) - return examples + return data From 687301a183ff4658d030297ea0765ee9e63ceb8d Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 17:59:44 +0330 Subject: [PATCH 05/33] :sparkles: Add sequence labeling dataset processor --- hezar/data/dataset_processors/__init__.py | 1 + .../sequence_labeling_processor.py | 104 ++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 hezar/data/dataset_processors/sequence_labeling_processor.py diff --git a/hezar/data/dataset_processors/__init__.py b/hezar/data/dataset_processors/__init__.py index 5651e297..09015d36 100644 --- a/hezar/data/dataset_processors/__init__.py +++ b/hezar/data/dataset_processors/__init__.py @@ -1,3 +1,4 @@ from .dataset_processor import DatasetProcessor from .image_captioning_processor import ImageCaptioningDatasetProcessor +from .sequence_labeling_processor import SequenceLabelingDatasetProcessor from .text_classification_processor import TextClassificationDatasetProcessor diff --git a/hezar/data/dataset_processors/sequence_labeling_processor.py b/hezar/data/dataset_processors/sequence_labeling_processor.py new file mode 100644 index 00000000..3c60e893 --- /dev/null +++ b/hezar/data/dataset_processors/sequence_labeling_processor.py @@ -0,0 +1,104 @@ +import torch + +from .dataset_processor import DatasetProcessor + + +class SequenceLabelingDatasetProcessor(DatasetProcessor): + def __init__(self, tokenizer, label_all_tokens=False, ignore_index=-100, max_length=None, padding=None): + super().__init__() + self.tokenizer = tokenizer + self.label_all_tokens = label_all_tokens + self.ignore_index = ignore_index + self.max_length = max_length + self.padding = padding + + def _tokenize_and_align(self, tokens, labels, padding=None, max_length=None): + """ + Tokenize and align tokens and labels for sequence labeling tasks. + + Args: + tokens: List of tokens (for single examples) or list of lists (for batches). + labels: List of labels (for single examples) or list of lists (for batches). + padding: Padding strategy for tokenization. + max_length: Maximum sequence length to truncate/pad. + + Returns: + dict: Tokenized and aligned inputs with labels. 
+ """ + padding = padding or self.padding + max_length = max_length or self.max_length + + # Tokenize and return word IDs for mapping labels to subword tokens + tokenized_inputs = self.tokenizer( + tokens, + is_split_into_words=True, + return_word_ids=True, + padding=padding, + truncation=True, + max_length=max_length, + return_tensors="torch" + ) + word_ids = tokenized_inputs["word_ids"] + + # Align labels with tokens + aligned_labels = [] + for batch_idx, batch_word_ids in enumerate(word_ids): + previous_word_idx = None + label_ids = [] + for word_idx in batch_word_ids: + # Assign ignore index for special tokens + if word_idx is None: + label_ids.append(self.ignore_index) + elif word_idx != previous_word_idx: + # Assign the label for the first token of each word + label_ids.append(labels[batch_idx][word_idx]) + else: + # Assign label for subword tokens (if label_all_tokens is True) + label_ids.append(labels[batch_idx][word_idx] if self.label_all_tokens else self.ignore_index) + previous_word_idx = word_idx + aligned_labels.append(label_ids) + + tokenized_inputs["labels"] = torch.tensor(aligned_labels, dtype=torch.long) + return tokenized_inputs + + def process_single(self, data, padding=None, max_length=None): + """ + Process a single example of sequence labeling data. + + Args: + data: A single data example containing tokens and labels. + padding: Padding strategy. + max_length: Maximum sequence length. + + Returns: + dict: Tokenized and aligned input data. + """ + tokens = data["tokens"] + labels = data["pos_tags"] + + tokenized_inputs = self._tokenize_and_align([tokens], [labels], padding=padding, max_length=max_length) + + data.update(tokenized_inputs) + + return data + + def process_batch(self, data, padding=None, max_length=None): + """ + Process a batch of sequence labeling examples. + + Args: + data: A batch of examples, containing tokens and labels. + padding: Padding strategy. + max_length: Maximum sequence length. + + Returns: + dict: Tokenized and aligned batch data. 
+ """ + tokens = data["tokens"] + labels = data["pos_tags"] + + tokenized_inputs = self._tokenize_and_align(tokens, labels, padding=padding, max_length=max_length) + + data.update(tokenized_inputs) + + return data From d41d9f186812cdc6c00721533a3976381eebbe0b Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 18:13:35 +0330 Subject: [PATCH 06/33] :sparkles: Add sequence labeling dataset processor --- hezar/data/dataset_processors/__init__.py | 1 + .../sequence_labeling_processor.py | 104 ++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 hezar/data/dataset_processors/sequence_labeling_processor.py diff --git a/hezar/data/dataset_processors/__init__.py b/hezar/data/dataset_processors/__init__.py index 5651e297..09015d36 100644 --- a/hezar/data/dataset_processors/__init__.py +++ b/hezar/data/dataset_processors/__init__.py @@ -1,3 +1,4 @@ from .dataset_processor import DatasetProcessor from .image_captioning_processor import ImageCaptioningDatasetProcessor +from .sequence_labeling_processor import SequenceLabelingDatasetProcessor from .text_classification_processor import TextClassificationDatasetProcessor diff --git a/hezar/data/dataset_processors/sequence_labeling_processor.py b/hezar/data/dataset_processors/sequence_labeling_processor.py new file mode 100644 index 00000000..ec213e6c --- /dev/null +++ b/hezar/data/dataset_processors/sequence_labeling_processor.py @@ -0,0 +1,104 @@ +import torch + +from .dataset_processor import DatasetProcessor + + +class SequenceLabelingDatasetProcessor(DatasetProcessor): + def __init__(self, tokenizer, label_all_tokens=True, ignore_index=-100, max_length=None, padding=None): + super().__init__() + self.tokenizer = tokenizer + self.label_all_tokens = label_all_tokens + self.ignore_index = ignore_index + self.max_length = max_length + self.padding = padding + + def _tokenize_and_align(self, tokens, labels, padding=None, max_length=None): + """ + Tokenize and align tokens and labels for sequence labeling tasks. + + Args: + tokens: List of tokens (for single examples) or list of lists (for batches). + labels: List of labels (for single examples) or list of lists (for batches). + padding: Padding strategy for tokenization. + max_length: Maximum sequence length to truncate/pad. + + Returns: + dict: Tokenized and aligned inputs with labels. 
+ """ + padding = padding or self.padding + max_length = max_length or self.max_length + + # Tokenize and return word IDs for mapping labels to subword tokens + tokenized_inputs = self.tokenizer( + tokens, + is_split_into_words=True, + return_word_ids=True, + padding=padding, + truncation=True, + max_length=max_length, + return_tensors="torch" + ) + word_ids = tokenized_inputs["word_ids"] + + # Align labels with tokens + aligned_labels = [] + for batch_idx, batch_word_ids in enumerate(word_ids): + previous_word_idx = None + label_ids = [] + for word_idx in batch_word_ids: + # Assign ignore index for special tokens + if word_idx is None: + label_ids.append(self.ignore_index) + elif word_idx != previous_word_idx: + # Assign the label for the first token of each word + label_ids.append(labels[batch_idx][word_idx]) + else: + # Assign label for subword tokens (if label_all_tokens is True) + label_ids.append(labels[batch_idx][word_idx] if self.label_all_tokens else self.ignore_index) + previous_word_idx = word_idx + aligned_labels.append(label_ids) + + tokenized_inputs["labels"] = torch.tensor(aligned_labels, dtype=torch.long) + return tokenized_inputs + + def process_single(self, data, padding=None, max_length=None): + """ + Process a single example of sequence labeling data. + + Args: + data: A single data example containing tokens and labels. + padding: Padding strategy. + max_length: Maximum sequence length. + + Returns: + dict: Tokenized and aligned input data. + """ + tokens = data["tokens"] + labels = data["pos_tags"] + + tokenized_inputs = self._tokenize_and_align([tokens], [labels], padding=padding, max_length=max_length) + + data.update(tokenized_inputs) + + return data + + def process_batch(self, data, padding=None, max_length=None): + """ + Process a batch of sequence labeling examples. + + Args: + data: A batch of examples, containing tokens and labels. + padding: Padding strategy. + max_length: Maximum sequence length. + + Returns: + dict: Tokenized and aligned batch data. + """ + tokens = data["tokens"] + labels = data["pos_tags"] + + tokenized_inputs = self._tokenize_and_align(tokens, labels, padding=padding, max_length=max_length) + + data.update(tokenized_inputs) + + return data From 64e511df689798a1b6e0f8e7b4b0445510061081 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 18:23:08 +0330 Subject: [PATCH 07/33] :pencil2: Rename `max_target_length` -> `labels_max_length` for text summarization --- hezar/data/data_collators.py | 8 ++++---- hezar/data/datasets/text_summarization_dataset.py | 8 ++++---- tests/test_trainer.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index bef6dcf8..1a685c6d 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -114,7 +114,7 @@ class TextGenerationDataCollator: padding_side (str): Specifies from which side of each tensor to add paddings, either `left` or `right` max_length (int): If `padding` is set to `max_length` this must be specified. Forces all tensors to have this value as length. - max_target_length (int): Maximum target length for text generation. + labels_max_length (int): Maximum target length for text generation. return_tensors (str): Specifies the dtype of the returning tensors in the batch. 
(`numpy`, `list`, `torch`) """ @@ -125,14 +125,14 @@ def __init__( padding: str = "longest", padding_side: str = "right", max_length: int = None, - max_target_length: int = None, + labels_max_length: int = None, return_tensors: str = "torch", ): self.tokenizer = tokenizer self.padding = padding self.padding_side = padding_side self.max_length = max_length - self.max_target_length = max_target_length + self.labels_max_length = labels_max_length self.return_tensors = return_tensors if padding == "longest" and max_length is not None: @@ -169,7 +169,7 @@ def __call__(self, encoded_batch): padded_batch = self.tokenizer.pad_encoded_batch( padded_batch, padding=self.padding, - max_length=self.max_target_length, + max_length=self.labels_max_length, include_keys=["labels"], return_tensors=self.return_tensors, ) diff --git a/hezar/data/datasets/text_summarization_dataset.py b/hezar/data/datasets/text_summarization_dataset.py index b85f1d3f..aaa7be91 100644 --- a/hezar/data/datasets/text_summarization_dataset.py +++ b/hezar/data/datasets/text_summarization_dataset.py @@ -26,7 +26,7 @@ class TextSummarizationDatasetConfig(DatasetConfig): summary_field (str): Field name for summary in the dataset. title_field (str): Field name for title in the dataset. max_length (int): Maximum length of text. - max_target_length (int): Maximum length of the target summary. + labels_max_length (int): Maximum length of the target summary. """ name = "text_summarization" @@ -37,7 +37,7 @@ class TextSummarizationDatasetConfig(DatasetConfig): summary_field: str = None title_field: str = None max_length: int = None - max_target_length: int = None + labels_max_length: int = None @register_dataset("text_summarization", config_class=TextSummarizationDatasetConfig) @@ -58,7 +58,7 @@ def __init__(self, config: TextSummarizationDatasetConfig, split=None, preproces self.data_collator = TextGenerationDataCollator( tokenizer=self.tokenizer, max_length=self.config.max_length, - max_target_length=self.config.max_target_length, + labels_max_length=self.config.labels_max_length, padding="max_length" if self.config.max_length else "longest", ) @@ -103,7 +103,7 @@ def __getitem__(self, index): summary, return_tensors="torch", max_length=self.config.max_length, - padding="max_length" if self.config.max_target_length else "longest", + padding="max_length" if self.config.labels_max_length else "longest", return_attention_mask=True, ) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 6a003cbe..9a2ae80c 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -60,7 +60,7 @@ "config": { "max_size": 4, "max_length": 128, - "max_target_length": 32, + "labels_max_length": 32, } }, "model": { From 8f1f0b2a0cb25ccd91c5119905da5a1320ed42e8 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 18:53:33 +0330 Subject: [PATCH 08/33] :sparkles: Add text summarization dataset processor --- hezar/data/dataset_processors/__init__.py | 1 + .../text_summarization_processor.py | 114 ++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 hezar/data/dataset_processors/text_summarization_processor.py diff --git a/hezar/data/dataset_processors/__init__.py b/hezar/data/dataset_processors/__init__.py index 09015d36..2ae38dd9 100644 --- a/hezar/data/dataset_processors/__init__.py +++ b/hezar/data/dataset_processors/__init__.py @@ -2,3 +2,4 @@ from .image_captioning_processor import ImageCaptioningDatasetProcessor from .sequence_labeling_processor import SequenceLabelingDatasetProcessor from 
.text_classification_processor import TextClassificationDatasetProcessor +from .text_summarization_processor import TextSummarizationDatasetProcessor diff --git a/hezar/data/dataset_processors/text_summarization_processor.py b/hezar/data/dataset_processors/text_summarization_processor.py new file mode 100644 index 00000000..bebd17d0 --- /dev/null +++ b/hezar/data/dataset_processors/text_summarization_processor.py @@ -0,0 +1,114 @@ +from .dataset_processor import DatasetProcessor + + +class TextSummarizationDatasetProcessor(DatasetProcessor): + def __init__( + self, + tokenizer, + prefix=None, + max_length=None, + labels_max_length=None, + text_field="text", + summary_field="summary", + padding="longest", + ): + super().__init__() + self.tokenizer = tokenizer + self.prefix = prefix + self.max_length = max_length + self.labels_max_length = labels_max_length + self.text_field = text_field + self.summary_field = summary_field + self.padding = padding + + def process_single(self, data, padding=None, max_length=None, labels_max_length=None): + """ + Process a single example for text summarization. + + Args: + data: A data example containing text and summary. + padding: Padding strategy. + max_length: Max length for input text. + labels_max_length: Max length for summary labels. + + Returns: + dict: Tokenized inputs and labels for summarization task. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + labels_max_length = labels_max_length or self.labels_max_length + + text = data[self.text_field] + summary = data[self.summary_field] + + # Add prefix if needed for conditional generation + if self.prefix is not None: + text = self.prefix + text + + # Tokenize inputs and labels + inputs = self.tokenizer( + text, + return_tensors="torch", + max_length=max_length, + padding=padding, + return_attention_mask=True, + truncation=True + ) + labels = self.tokenizer( + summary, + return_tensors="torch", + max_length=labels_max_length, + padding=padding, + return_attention_mask=True, + truncation=True + ) + + inputs["labels"] = labels["token_ids"].clone() + + return inputs + + def process_batch(self, data, padding=None, max_length=None, labels_max_length=None): + """ + Process a batch of examples for text summarization. + + Args: + data: A batch of examples containing texts and summaries. + padding: Padding strategy. + max_length: Max length for input texts. + labels_max_length: Max length for summary labels. + + Returns: + dict: Tokenized inputs and labels for summarization task. 
+ """ + padding = padding or self.padding + max_length = max_length or self.max_length + labels_max_length = labels_max_length or self.labels_max_length + + texts = data[self.text_field] + summaries = data[self.summary_field] + + # Add prefix if needed for conditional generation + if self.prefix is not None: + texts = [self.prefix + text for text in texts] + + # Tokenize inputs and labels in batch + inputs = self.tokenizer( + texts, + return_tensors="torch", + max_length=max_length, + padding=padding, + return_attention_mask=True, + truncation=True + ) + labels = self.tokenizer( + summaries, + return_tensors="torch", + max_length=labels_max_length, + padding=padding, + return_attention_mask=True, + truncation=True + ) + + inputs["labels"] = labels["token_ids"].clone() + + return inputs From 9c6c0a479753c82c5aca786f579631a278ed607a Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 19:38:00 +0330 Subject: [PATCH 09/33] :sparkles: Add OCR dataset processor --- hezar/data/dataset_processors/__init__.py | 1 + .../data/dataset_processors/ocr_processor.py | 91 +++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 hezar/data/dataset_processors/ocr_processor.py diff --git a/hezar/data/dataset_processors/__init__.py b/hezar/data/dataset_processors/__init__.py index 2ae38dd9..ede1dc3e 100644 --- a/hezar/data/dataset_processors/__init__.py +++ b/hezar/data/dataset_processors/__init__.py @@ -1,5 +1,6 @@ from .dataset_processor import DatasetProcessor from .image_captioning_processor import ImageCaptioningDatasetProcessor +from .ocr_processor import OCRDatasetProcessor from .sequence_labeling_processor import SequenceLabelingDatasetProcessor from .text_classification_processor import TextClassificationDatasetProcessor from .text_summarization_processor import TextSummarizationDatasetProcessor diff --git a/hezar/data/dataset_processors/ocr_processor.py b/hezar/data/dataset_processors/ocr_processor.py new file mode 100644 index 00000000..ef727379 --- /dev/null +++ b/hezar/data/dataset_processors/ocr_processor.py @@ -0,0 +1,91 @@ +import torch + +from ...utils import reverse_string_digits +from .dataset_processor import DatasetProcessor + + +class OCRDatasetProcessor(DatasetProcessor): + def __init__( + self, + image_processor, + tokenizer=None, + text_split_type="char_split", + max_length=None, + reverse_digits=False, + id2label=None, + image_field="image_path", + text_field="text", + ): + super().__init__() + self.image_processor = image_processor + self.tokenizer = tokenizer + self.text_split_type = text_split_type + self.max_length = max_length + self.reverse_digits = reverse_digits + self.id2label = id2label + self.image_field = image_field + self.text_field = text_field + + def _text_to_tensor(self, text): + """ + Convert text to tensor based on the configured text_split_type. + + Args: + text (str): The raw text. + + Returns: + torch.Tensor: The output tensor. 
+ + """ + if self.text_split_type == "tokenize": + token_ids = self.tokenizer(text, padding="max_length", max_length=self.max_length)["input_ids"] + token_ids = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids] + labels = torch.tensor(token_ids) + elif self.text_split_type == "char_split": + if self.reverse_digits: + text = reverse_string_digits(text) + label2id = {v: k for k, v in self.id2label.items()} + labels = [label2id[char] for char in text] + labels = torch.LongTensor(labels) + else: + raise ValueError(f"Invalid `text_split_type={self.text_split_type}`") + return labels + + def process_single(self, data): + """ + Process a single image-to-text OCR example. + + Args: + data: A data example containing an image path and corresponding text. + + Returns: + dict: Processed inputs with pixel values and text labels. + """ + path = data[self.image_field] + text = data[self.text_field] + pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"][0] + labels = self._text_to_tensor(text) + return {"pixel_values": pixel_values, "labels": labels} + + def process_batch(self, data): + """ + Process a batch of image-to-text OCR examples. + + Args: + data: A batch of data examples containing image paths and corresponding texts. + + Returns: + dict: Batch of processed inputs with pixel values and text labels. + """ + paths = data[self.image_field] + texts = data[self.text_field] + + # Process images in batch + pixel_values = self.image_processor(paths, return_tensors="torch")["pixel_values"] + + # Process text labels in batch + labels = [] + for text in texts: + labels.append(self._text_to_tensor(text)) + + return {"pixel_values": pixel_values, "labels": labels} From 18e67e82bb0a6f18948de74bafe565f68617a431 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 20:20:29 +0330 Subject: [PATCH 10/33] :sparkles: Add speech recognition dataset processor --- hezar/data/dataset_processors/__init__.py | 1 + .../speech_recognition_processor.py | 98 +++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 hezar/data/dataset_processors/speech_recognition_processor.py diff --git a/hezar/data/dataset_processors/__init__.py b/hezar/data/dataset_processors/__init__.py index ede1dc3e..3d0c81a4 100644 --- a/hezar/data/dataset_processors/__init__.py +++ b/hezar/data/dataset_processors/__init__.py @@ -2,5 +2,6 @@ from .image_captioning_processor import ImageCaptioningDatasetProcessor from .ocr_processor import OCRDatasetProcessor from .sequence_labeling_processor import SequenceLabelingDatasetProcessor +from .speech_recognition_processor import SpeechRecognitionDatasetProcessor from .text_classification_processor import TextClassificationDatasetProcessor from .text_summarization_processor import TextSummarizationDatasetProcessor diff --git a/hezar/data/dataset_processors/speech_recognition_processor.py b/hezar/data/dataset_processors/speech_recognition_processor.py new file mode 100644 index 00000000..1600d4a1 --- /dev/null +++ b/hezar/data/dataset_processors/speech_recognition_processor.py @@ -0,0 +1,98 @@ +from .dataset_processor import DatasetProcessor + + +class SpeechRecognitionDatasetProcessor(DatasetProcessor): + def __init__( + self, + feature_extractor, + tokenizer, + sampling_rate=16000, + audio_array_padding="longest", + max_audio_array_length=None, + labels_padding="longest", + labels_max_length=None, + audio_field="audio", + transcript_field="transcript", + ): + super().__init__() + self.feature_extractor = 
feature_extractor + self.tokenizer = tokenizer + self.sampling_rate = sampling_rate + self.audio_array_padding = audio_array_padding + self.max_audio_array_length = max_audio_array_length + self.labels_padding = labels_padding + self.labels_max_length = labels_max_length + self.audio_field = audio_field + self.transcript_field = transcript_field + + def process_single(self, data): + """ + Process a single speech recognition example. + + Args: + data: A data example containing audio and its transcript. + + Returns: + dict: Processed input features and labels. + """ + audio_array = data[self.audio_field]["array"] + transcript = data[self.transcript_field] + + # Extract input features from audio + input_features = self.feature_extractor( + audio_array, + sampling_rate=self.sampling_rate, + padding=self.audio_array_padding, + max_length=self.max_audio_array_length, + return_tensors="torch", + )["input_features"] + + # Tokenize the transcript + labels = self.tokenizer( + transcript, + padding=self.labels_padding, + max_length=self.labels_max_length, + return_tensors="torch", + ) + + data["input_features"] = input_features + data["labels"] = labels["token_ids"] + data["attention_mask"] = labels["attention_mask"] + + return data + + def process_batch(self, data): + """ + Process a batch of speech recognition examples. + + Args: + data: A batch of data examples containing audio arrays and their corresponding transcripts. + + Returns: + dict: Batch of processed input features and labels. + """ + audio_arrays = [x["array"] for x in data[self.audio_field]] + transcripts = data[self.transcript_field] + + # Extract input features in batch + input_features = self.feature_extractor( + audio_arrays, + sampling_rate=self.sampling_rate, + padding=self.audio_array_padding, + max_length=self.max_audio_array_length, + return_tensors="torch", + )["input_features"] + + # Tokenize transcripts in batch + labels = self.tokenizer( + transcripts, + padding=self.labels_padding, + max_length=self.labels_max_length, + return_tensors="torch", + ) + + data["input_features"] = input_features + data["labels"] = labels["token_ids"] + data["attention_mask"] = labels["attention_mask"] + + return data From bb412f42b64519a3caebdb8be27a53951d8f8360 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 20:49:14 +0330 Subject: [PATCH 11/33] :fire: Move all dataset processors to one file (`dataset_processors.py`) --- hezar/data/dataset_processors.py | 619 ++++++++++++++++++ hezar/data/dataset_processors/__init__.py | 7 - .../dataset_processors/dataset_processor.py | 60 -- .../image_captioning_processor.py | 73 --- .../data/dataset_processors/ocr_processor.py | 91 --- .../sequence_labeling_processor.py | 104 --- .../speech_recognition_processor.py | 98 --- .../text_classification_processor.py | 57 -- .../text_summarization_processor.py | 114 ---- 9 files changed, 619 insertions(+), 604 deletions(-) create mode 100644 hezar/data/dataset_processors.py delete mode 100644 hezar/data/dataset_processors/__init__.py delete mode 100644 hezar/data/dataset_processors/dataset_processor.py delete mode 100644 hezar/data/dataset_processors/image_captioning_processor.py delete mode 100644 hezar/data/dataset_processors/ocr_processor.py delete mode 100644 hezar/data/dataset_processors/sequence_labeling_processor.py delete mode 100644 hezar/data/dataset_processors/speech_recognition_processor.py delete mode 100644 hezar/data/dataset_processors/text_classification_processor.py delete mode 100644 
hezar/data/dataset_processors/text_summarization_processor.py diff --git a/hezar/data/dataset_processors.py b/hezar/data/dataset_processors.py new file mode 100644 index 00000000..744fe428 --- /dev/null +++ b/hezar/data/dataset_processors.py @@ -0,0 +1,619 @@ +""" +Dataset processors are a bunch of callable classes to be passed as map functions for any dataset on the Hub. +Note that the main dataset classes are already implemented in a way that the processing is done in the `__getitem__` +method and these classes are only used for when the dataset has been loaded using the HuggingFace datasets library, +and you want to get advantage of the multiprocessing/batch processing/caching functionalities of the HF datasets. + +Example: +>>> from datasets import load_dataset +>>> from hezar.data import SpeechRecognitionDatasetProcessor + +>>> data_processor = SpeechRecognitionDatasetProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) +>>> dataset = load_dataset("hezarai/common-voice-13-fa") +>>> dataset = dataset.map(data_processor, batched=True, batch_size=1000) +""" +import torch + +from ..constants import Backends +from ..utils import is_backend_available, reverse_string_digits, shift_tokens_right, verify_dependencies + + +if is_backend_available(Backends.DATASETS): + from datasets.formatting.formatting import LazyBatch, LazyRow + +__all__ = [ + "DatasetProcessor", + "ImageCaptioningDatasetProcessor", + "OCRDatasetProcessor", + "SequenceLabelingDatasetProcessor", + "SpeechRecognitionDatasetProcessor", + "TextClassificationDatasetProcessor", + "TextSummarizationDatasetProcessor", +] + + +class DatasetProcessor: + """ + The base callable dataset processor class that can handle both single and batched mode dataset mapping. + """ + required_backends = [Backends.DATASETS] + + def __init__(self, *args, **kwargs): + verify_dependencies(self, self.required_backends) + self.args = args + self.kwargs = kwargs + + def __call__(self, data: LazyBatch | LazyRow, **kwargs): + """ + Method called when using the map function. + Decides whether to call `process_single()` or `process_batch()` based on the data values. + + Args: + data: A dict of feature name -> sample or batch of samples mapping. + **kwargs: Additional keyword arguments passed through the `map` function as `kwargs`. + """ + if isinstance(data, LazyRow): + return self.process_single(data, **kwargs) + elif isinstance(data, LazyBatch): + return self.process_batch(data, **kwargs) + else: + raise ValueError(f"The input data must be either `LazyBatch` or `LazyRow`, got `{type(data)}`!") + + def process_single(self, data: LazyRow, **kwargs): + """ + Method to process a single example + """ + raise NotImplementedError + + def process_batch(self, data: LazyBatch, **kwargs): + """ + Method to process a batch of examples. + """ + raise NotImplementedError + + +class ImageCaptioningDatasetProcessor(DatasetProcessor): + """ + Dataset processor for image captioning datasets. This class handles tokenization and image processing. + """ + + def __init__(self, image_processor, tokenizer, max_length=None, padding=None): + super().__init__() + self.image_processor = image_processor + self.tokenizer = tokenizer + self.max_length = max_length + self.padding = padding + + def process_single(self, data, padding=None, max_length=None): + """ + Process image and tokenize captions for a single data sample. + + Args: + data: A data example containing the image and its caption + padding: Padding type e.g, max_length, longest. 
+ max_length: Max length value if padding is set to max_length or the labels must be truncated. + + Returns: + A dict of pixel values tensor of the processed image and labels token ids and attention mask. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + path = data["image_path"] + text = data["label"] + + tokenized_inputs = self.tokenizer(text, padding=padding, max_length=max_length, return_tensors="torch") + + data["pixel_values"] = self.image_processor(path, return_tensors="torch")["pixel_values"] + data["labels"] = tokenized_inputs["token_ids"] + data["attention_mask"] = tokenized_inputs["attention_mask"] + data["decoder_input_ids"] = shift_tokens_right( + data["labels"], + pad_token_id=self.tokenizer.pad_token_id, + decoder_start_token_id=self.tokenizer.bos_token_id, + ) + + return data + + def process_batch(self, data, padding=None, max_length=None): + """ + Process image and tokenize captions for a batch of data samples. + + Args: + data: A batch of data examples containing the images and their captions + padding: Padding type e.g, max_length, longest. + max_length: Max length value if padding is set to max_length or the labels must be truncated. + + Returns: + A dict of pixel values tensor of the processed images and labels token ids and attention masks. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + paths = data["image_path"] + texts = data["label"] + + tokenized_inputs = self.tokenizer(texts, padding=padding, max_length=max_length, return_tensors="torch") + + data["pixel_values"] = self.image_processor(paths, return_tensors="torch")["pixel_values"] + data["labels"] = tokenized_inputs["token_ids"] + data["attention_mask"] = tokenized_inputs["attention_mask"] + data["decoder_input_ids"] = shift_tokens_right( + data["labels"], + pad_token_id=self.tokenizer.pad_token_id, + decoder_start_token_id=self.tokenizer.bos_token_id, + ) + + return data + + +class OCRDatasetProcessor(DatasetProcessor): + """ + Dataset processor class for OCR which can handle both tokenizer-based or character-split-based datasets. + """ + + def __init__( + self, + image_processor, + tokenizer=None, + text_split_type="char_split", + max_length=None, + reverse_digits=False, + id2label=None, + image_field="image_path", + text_field="text", + ): + super().__init__() + self.image_processor = image_processor + self.tokenizer = tokenizer + self.text_split_type = text_split_type + self.max_length = max_length + self.reverse_digits = reverse_digits + self.id2label = id2label + self.image_field = image_field + self.text_field = text_field + + def _text_to_tensor(self, text): + """ + Convert text to tensor based on the configured text_split_type. + + Args: + text (str): The raw text. + + Returns: + torch.Tensor: The output tensor. + + """ + if self.text_split_type == "tokenize": + token_ids = self.tokenizer(text, padding="max_length", max_length=self.max_length)["input_ids"] + token_ids = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids] + labels = torch.tensor(token_ids) + elif self.text_split_type == "char_split": + if self.reverse_digits: + text = reverse_string_digits(text) + label2id = {v: k for k, v in self.id2label.items()} + labels = [label2id[char] for char in text] + labels = torch.LongTensor(labels) + else: + raise ValueError(f"Invalid `text_split_type={self.text_split_type}`") + return labels + + def process_single(self, data): + """ + Process a single image-to-text OCR example. 
+ + Args: + data: A data example containing an image path and corresponding text. + + Returns: + dict: Processed inputs with pixel values and text labels. + """ + path = data[self.image_field] + text = data[self.text_field] + pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"][0] + labels = self._text_to_tensor(text) + return {"pixel_values": pixel_values, "labels": labels} + + def process_batch(self, data): + """ + Process a batch of image-to-text OCR examples. + + Args: + data: A batch of data examples containing image paths and corresponding texts. + + Returns: + dict: Batch of processed inputs with pixel values and text labels. + """ + paths = data[self.image_field] + texts = data[self.text_field] + + # Process images in batch + pixel_values = self.image_processor(paths, return_tensors="torch")["pixel_values"] + + # Process text labels in batch + labels = [] + for text in texts: + labels.append(self._text_to_tensor(text)) + + return {"pixel_values": pixel_values, "labels": labels} + + +class SequenceLabelingDatasetProcessor(DatasetProcessor): + """ + Dataset processor class for sequence labeling datasets. Handles tokenization and label alignment. + """ + + def __init__(self, tokenizer, label_all_tokens=True, ignore_index=-100, max_length=None, padding=None): + super().__init__() + self.tokenizer = tokenizer + self.label_all_tokens = label_all_tokens + self.ignore_index = ignore_index + self.max_length = max_length + self.padding = padding + + def _tokenize_and_align(self, tokens, labels, padding=None, max_length=None): + """ + Tokenize and align tokens and labels for sequence labeling tasks. + + Args: + tokens: List of tokens (for single examples) or list of lists (for batches). + labels: List of labels (for single examples) or list of lists (for batches). + padding: Padding strategy for tokenization. + max_length: Maximum sequence length to truncate/pad. + + Returns: + dict: Tokenized and aligned inputs with labels. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + # Tokenize and return word IDs for mapping labels to subword tokens + tokenized_inputs = self.tokenizer( + tokens, + is_split_into_words=True, + return_word_ids=True, + padding=padding, + truncation=True, + max_length=max_length, + return_tensors="torch" + ) + word_ids = tokenized_inputs["word_ids"] + + # Align labels with tokens + aligned_labels = [] + for batch_idx, batch_word_ids in enumerate(word_ids): + previous_word_idx = None + label_ids = [] + for word_idx in batch_word_ids: + # Assign ignore index for special tokens + if word_idx is None: + label_ids.append(self.ignore_index) + elif word_idx != previous_word_idx: + # Assign the label for the first token of each word + label_ids.append(labels[batch_idx][word_idx]) + else: + # Assign label for subword tokens (if label_all_tokens is True) + label_ids.append(labels[batch_idx][word_idx] if self.label_all_tokens else self.ignore_index) + previous_word_idx = word_idx + aligned_labels.append(label_ids) + + tokenized_inputs["labels"] = torch.tensor(aligned_labels, dtype=torch.long) + return tokenized_inputs + + def process_single(self, data, padding=None, max_length=None): + """ + Process a single example of sequence labeling data. + + Args: + data: A single data example containing tokens and labels. + padding: Padding strategy. + max_length: Maximum sequence length. + + Returns: + dict: Tokenized and aligned input data. 
+ """ + tokens = data["tokens"] + labels = data["pos_tags"] + + tokenized_inputs = self._tokenize_and_align([tokens], [labels], padding=padding, max_length=max_length) + + data.update(tokenized_inputs) + + return data + + def process_batch(self, data, padding=None, max_length=None): + """ + Process a batch of sequence labeling examples. + + Args: + data: A batch of examples, containing tokens and labels. + padding: Padding strategy. + max_length: Maximum sequence length. + + Returns: + dict: Tokenized and aligned batch data. + """ + tokens = data["tokens"] + labels = data["pos_tags"] + + tokenized_inputs = self._tokenize_and_align(tokens, labels, padding=padding, max_length=max_length) + + data.update(tokenized_inputs) + + return data + + +class SpeechRecognitionDatasetProcessor(DatasetProcessor): + """ + Processor class for speech recognition datasets. Handles audio feature extraction and labels tokenization. + """ + + def __init__( + self, + feature_extractor, + tokenizer, + sampling_rate=16000, + audio_array_padding="longest", + max_audio_array_length=None, + labels_padding="longest", + labels_max_length=None, + audio_field="audio", + transcript_field="transcript", + ): + super().__init__() + self.feature_extractor = feature_extractor + self.tokenizer = tokenizer + self.sampling_rate = sampling_rate + self.audio_array_padding = audio_array_padding + self.max_audio_array_length = max_audio_array_length + self.labels_padding = labels_padding + self.labels_max_length = labels_max_length + self.audio_field = audio_field + self.transcript_field = transcript_field + + def process_single(self, data): + """ + Process a single speech recognition example. + + Args: + data: A data example containing audio and its transcript. + + Returns: + dict: Processed input features and labels. + """ + audio_array = data[self.audio_field]["array"] + transcript = data[self.transcript_field] + + # Extract input features from audio + input_features = self.feature_extractor( + audio_array, + sampling_rate=self.sampling_rate, + padding=self.audio_array_padding, + max_length=self.max_audio_array_length, + return_tensors="torch", + )["input_features"] + + # Tokenize the transcript + labels = self.tokenizer( + transcript, + padding=self.labels_padding, + max_length=self.labels_max_length, + return_tensors="torch", + ) + + data["input_features"] = input_features + data["labels"] = labels["token_ids"] + data["attention_mask"] = labels["attention_mask"] + + return data + + def process_batch(self, data): + """ + Process a batch of speech recognition examples. + + Args: + data: A batch of data examples containing audio arrays and their corresponding transcripts. + + Returns: + dict: Batch of processed input features and labels. 
+ """ + audio_arrays = [x["array"] for x in data[self.audio_field]] + transcripts = data[self.transcript_field] + + # Extract input features in batch + input_features = self.feature_extractor( + audio_arrays, + sampling_rate=self.sampling_rate, + padding=self.audio_array_padding, + max_length=self.max_audio_array_length, + return_tensors="torch", + )["input_features"] + + # Tokenize transcripts in batch + labels = self.tokenizer( + transcripts, + padding=self.labels_padding, + max_length=self.labels_max_length, + return_tensors="torch", + ) + + data["input_features"] = input_features + data["labels"] = labels["token_ids"] + data["attention_mask"] = labels["attention_mask"] + + return data + + +class TextClassificationDatasetProcessor(DatasetProcessor): + """ + Processor class for text classification datasets. Handles tokenization of the texts. + """ + + def __init__(self, tokenizer, max_length=None, padding=None): + super().__init__() + self.tokenizer = tokenizer + self.padding = padding + self.max_length = max_length + + def process_single(self, data, padding=None, max_length=None): + """ + Process a single example. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + text = data["text"] + label = data["label"] + + inputs = self.tokenizer( + text, + return_tensors="torch", + truncation="longest_first", + padding=padding, + max_length=max_length, + return_attention_mask=True, + ) + data.update(inputs) + data["labels"] = torch.tensor([label], dtype=torch.long) + + return data + + def process_batch(self, data, padding=None, max_length=None): + """ + Process a batch of examples. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + texts = data["text"] + labels = data["label"] + + inputs = self.tokenizer( + texts, + return_tensors="torch", + truncation=True, + padding=padding, + max_length=max_length, + return_attention_mask=True, + ) + data.update(inputs) + data["labels"] = torch.tensor(labels, dtype=torch.long) + + return data + + +class TextSummarizationDatasetProcessor(DatasetProcessor): + """ + Processor class for text summarization datasets. Handles tokenization of the inputs and labels. + """ + + def __init__( + self, + tokenizer, + prefix=None, + max_length=None, + labels_max_length=None, + text_field="text", + summary_field="summary", + padding="longest", + ): + super().__init__() + self.tokenizer = tokenizer + self.prefix = prefix + self.max_length = max_length + self.labels_max_length = labels_max_length + self.text_field = text_field + self.summary_field = summary_field + self.padding = padding + + def process_single(self, data, padding=None, max_length=None, labels_max_length=None): + """ + Process a single example for text summarization. + + Args: + data: A data example containing text and summary. + padding: Padding strategy. + max_length: Max length for input text. + labels_max_length: Max length for summary labels. + + Returns: + dict: Tokenized inputs and labels for summarization task. 
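+
+        Example:
+            A minimal usage sketch; assumes `tokenizer` is a loaded Hezar `Tokenizer` (e.g.
+            `Tokenizer.load("hezarai/t5-base-fa")`) and `sample` is a single row with `text`
+            and `summary` fields, as in `hezarai/xlsum-fa`.
+
+            >>> processor = TextSummarizationDatasetProcessor(tokenizer=tokenizer, padding="longest")
+            >>> inputs = processor.process_single(sample)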
+ """ + padding = padding or self.padding + max_length = max_length or self.max_length + labels_max_length = labels_max_length or self.labels_max_length + + text = data[self.text_field] + summary = data[self.summary_field] + + # Add prefix if needed for conditional generation + if self.prefix is not None: + text = self.prefix + text + + # Tokenize inputs and labels + inputs = self.tokenizer( + text, + return_tensors="torch", + max_length=max_length, + padding=padding, + return_attention_mask=True, + truncation=True + ) + labels = self.tokenizer( + summary, + return_tensors="torch", + max_length=labels_max_length, + padding=padding, + return_attention_mask=True, + truncation=True + ) + + inputs["labels"] = labels["token_ids"].clone() + + return inputs + + def process_batch(self, data, padding=None, max_length=None, labels_max_length=None): + """ + Process a batch of examples for text summarization. + + Args: + data: A batch of examples containing texts and summaries. + padding: Padding strategy. + max_length: Max length for input texts. + labels_max_length: Max length for summary labels. + + Returns: + dict: Tokenized inputs and labels for summarization task. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + labels_max_length = labels_max_length or self.labels_max_length + + texts = data[self.text_field] + summaries = data[self.summary_field] + + # Add prefix if needed for conditional generation + if self.prefix is not None: + texts = [self.prefix + text for text in texts] + + # Tokenize inputs and labels in batch + inputs = self.tokenizer( + texts, + return_tensors="torch", + max_length=max_length, + padding=padding, + return_attention_mask=True, + truncation=True + ) + labels = self.tokenizer( + summaries, + return_tensors="torch", + max_length=labels_max_length, + padding=padding, + return_attention_mask=True, + truncation=True + ) + + inputs["labels"] = labels["token_ids"].clone() + + return inputs diff --git a/hezar/data/dataset_processors/__init__.py b/hezar/data/dataset_processors/__init__.py deleted file mode 100644 index 3d0c81a4..00000000 --- a/hezar/data/dataset_processors/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .dataset_processor import DatasetProcessor -from .image_captioning_processor import ImageCaptioningDatasetProcessor -from .ocr_processor import OCRDatasetProcessor -from .sequence_labeling_processor import SequenceLabelingDatasetProcessor -from .speech_recognition_processor import SpeechRecognitionDatasetProcessor -from .text_classification_processor import TextClassificationDatasetProcessor -from .text_summarization_processor import TextSummarizationDatasetProcessor diff --git a/hezar/data/dataset_processors/dataset_processor.py b/hezar/data/dataset_processors/dataset_processor.py deleted file mode 100644 index 37e94fe6..00000000 --- a/hezar/data/dataset_processors/dataset_processor.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Dataset processors are a bunch of callable classes to be passed to be used as map functions for any dataset on the Hub. -Note that the main dataset classes are already implemented in a way that the processing is done in the `__getitem__` -method and these classes are only used for when the dataset has been loaded using HuggingFace datasets library. 
- -Example: ->>> from datasets import load_dataset ->>> from hezar.data import SpeechRecognitionDatasetProcessor - ->>> data_processor = SpeechRecognitionDatasetProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) ->>> dataset = load_dataset("hezarai/common-voice-13-fa") ->>> dataset = dataset.map(data_processor, batched=True, batch_size=1000) - -""" -from ...constants import Backends -from ...utils import is_backend_available, verify_dependencies - - -if is_backend_available(Backends.DATASETS): - from datasets.formatting.formatting import LazyBatch, LazyRow - - -class DatasetProcessor: - required_backends = [Backends.DATASETS] - - def __init__(self, *args, **kwargs): - """ - Base constructor that accepts a `batched` flag and any other arguments for child class initialization. - """ - verify_dependencies(self, self.required_backends) - self.args = args - self.kwargs = kwargs - - def __call__(self, data: LazyBatch | LazyRow, **kwargs): - """ - Method called when using the map function. - Decides whether to call `process_single()` or `process_batch()` based on the data values. - - Args: - data: A dict of feature name -> sample or batch of samples mapping. - **kwargs: Additional keyword arguments passed through the `map` function as `kwargs`. - """ - if isinstance(data, LazyRow): - return self.process_single(data, **kwargs) - elif isinstance(data, LazyBatch): - return self.process_batch(data, **kwargs) - else: - raise ValueError(f"The input data must be either `LazyBatch` or `LazyRow`, got `{type(data)}`!") - - def process_single(self, data: LazyRow, **kwargs): - """ - Method to process a single example - """ - raise NotImplementedError - - def process_batch(self, data: LazyBatch, **kwargs): - """ - Method to process a batch of examples. - """ - raise NotImplementedError diff --git a/hezar/data/dataset_processors/image_captioning_processor.py b/hezar/data/dataset_processors/image_captioning_processor.py deleted file mode 100644 index 1f6d2a85..00000000 --- a/hezar/data/dataset_processors/image_captioning_processor.py +++ /dev/null @@ -1,73 +0,0 @@ -from ...utils import shift_tokens_right -from .dataset_processor import DatasetProcessor - - -class ImageCaptioningDatasetProcessor(DatasetProcessor): - def __init__(self, image_processor, tokenizer, max_length=None, padding=None): - super().__init__() - self.image_processor = image_processor - self.tokenizer = tokenizer - self.max_length = max_length - self.padding = padding - - def process_single(self, data, padding=None, max_length=None): - """ - Process image and tokenize captions for a single data sample. - - Args: - data: A data example containing the image and its caption - padding: Padding type e.g, max_length, longest. - max_length: Max length value if padding is set to max_length or the labels must be truncated. - - Returns: - A dict of pixel values tensor of the processed image and labels token ids and attention mask. 
- """ - padding = padding or self.padding - max_length = max_length or self.max_length - - path = data["image_path"] - text = data["label"] - - tokenized_inputs = self.tokenizer(text, padding=padding, max_length=max_length, return_tensors="torch") - - data["pixel_values"] = self.image_processor(path, return_tensors="torch")["pixel_values"] - data["labels"] = tokenized_inputs["token_ids"] - data["attention_mask"] = tokenized_inputs["attention_mask"] - data["decoder_input_ids"] = shift_tokens_right( - data["labels"], - pad_token_id=self.tokenizer.pad_token_id, - decoder_start_token_id=self.tokenizer.bos_token_id, - ) - - return data - - def process_batch(self, data, padding=None, max_length=None): - """ - Process image and tokenize captions for a batch of data samples. - - Args: - data: A batch of data examples containing the images and their captions - padding: Padding type e.g, max_length, longest. - max_length: Max length value if padding is set to max_length or the labels must be truncated. - - Returns: - A dict of pixel values tensor of the processed images and labels token ids and attention masks. - """ - padding = padding or self.padding - max_length = max_length or self.max_length - - paths = data["image_path"] - texts = data["label"] - - tokenized_inputs = self.tokenizer(texts, padding=padding, max_length=max_length, return_tensors="torch") - - data["pixel_values"] = self.image_processor(paths, return_tensors="torch")["pixel_values"] - data["labels"] = tokenized_inputs["token_ids"] - data["attention_mask"] = tokenized_inputs["attention_mask"] - data["decoder_input_ids"] = shift_tokens_right( - data["labels"], - pad_token_id=self.tokenizer.pad_token_id, - decoder_start_token_id=self.tokenizer.bos_token_id, - ) - - return data diff --git a/hezar/data/dataset_processors/ocr_processor.py b/hezar/data/dataset_processors/ocr_processor.py deleted file mode 100644 index ef727379..00000000 --- a/hezar/data/dataset_processors/ocr_processor.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch - -from ...utils import reverse_string_digits -from .dataset_processor import DatasetProcessor - - -class OCRDatasetProcessor(DatasetProcessor): - def __init__( - self, - image_processor, - tokenizer=None, - text_split_type="char_split", - max_length=None, - reverse_digits=False, - id2label=None, - image_field="image_path", - text_field="text", - ): - super().__init__() - self.image_processor = image_processor - self.tokenizer = tokenizer - self.text_split_type = text_split_type - self.max_length = max_length - self.reverse_digits = reverse_digits - self.id2label = id2label - self.image_field = image_field - self.text_field = text_field - - def _text_to_tensor(self, text): - """ - Convert text to tensor based on the configured text_split_type. - - Args: - text (str): The raw text. - - Returns: - torch.Tensor: The output tensor. 
- - """ - if self.text_split_type == "tokenize": - token_ids = self.tokenizer(text, padding="max_length", max_length=self.max_length)["input_ids"] - token_ids = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids] - labels = torch.tensor(token_ids) - elif self.text_split_type == "char_split": - if self.reverse_digits: - text = reverse_string_digits(text) - label2id = {v: k for k, v in self.id2label.items()} - labels = [label2id[char] for char in text] - labels = torch.LongTensor(labels) - else: - raise ValueError(f"Invalid `text_split_type={self.text_split_type}`") - return labels - - def process_single(self, data): - """ - Process a single image-to-text OCR example. - - Args: - data: A data example containing an image path and corresponding text. - - Returns: - dict: Processed inputs with pixel values and text labels. - """ - path = data[self.image_field] - text = data[self.text_field] - pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"][0] - labels = self._text_to_tensor(text) - return {"pixel_values": pixel_values, "labels": labels} - - def process_batch(self, data): - """ - Process a batch of image-to-text OCR examples. - - Args: - data: A batch of data examples containing image paths and corresponding texts. - - Returns: - dict: Batch of processed inputs with pixel values and text labels. - """ - paths = data[self.image_field] - texts = data[self.text_field] - - # Process images in batch - pixel_values = self.image_processor(paths, return_tensors="torch")["pixel_values"] - - # Process text labels in batch - labels = [] - for text in texts: - labels.append(self._text_to_tensor(text)) - - return {"pixel_values": pixel_values, "labels": labels} diff --git a/hezar/data/dataset_processors/sequence_labeling_processor.py b/hezar/data/dataset_processors/sequence_labeling_processor.py deleted file mode 100644 index ec213e6c..00000000 --- a/hezar/data/dataset_processors/sequence_labeling_processor.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch - -from .dataset_processor import DatasetProcessor - - -class SequenceLabelingDatasetProcessor(DatasetProcessor): - def __init__(self, tokenizer, label_all_tokens=True, ignore_index=-100, max_length=None, padding=None): - super().__init__() - self.tokenizer = tokenizer - self.label_all_tokens = label_all_tokens - self.ignore_index = ignore_index - self.max_length = max_length - self.padding = padding - - def _tokenize_and_align(self, tokens, labels, padding=None, max_length=None): - """ - Tokenize and align tokens and labels for sequence labeling tasks. - - Args: - tokens: List of tokens (for single examples) or list of lists (for batches). - labels: List of labels (for single examples) or list of lists (for batches). - padding: Padding strategy for tokenization. - max_length: Maximum sequence length to truncate/pad. - - Returns: - dict: Tokenized and aligned inputs with labels. 
- """ - padding = padding or self.padding - max_length = max_length or self.max_length - - # Tokenize and return word IDs for mapping labels to subword tokens - tokenized_inputs = self.tokenizer( - tokens, - is_split_into_words=True, - return_word_ids=True, - padding=padding, - truncation=True, - max_length=max_length, - return_tensors="torch" - ) - word_ids = tokenized_inputs["word_ids"] - - # Align labels with tokens - aligned_labels = [] - for batch_idx, batch_word_ids in enumerate(word_ids): - previous_word_idx = None - label_ids = [] - for word_idx in batch_word_ids: - # Assign ignore index for special tokens - if word_idx is None: - label_ids.append(self.ignore_index) - elif word_idx != previous_word_idx: - # Assign the label for the first token of each word - label_ids.append(labels[batch_idx][word_idx]) - else: - # Assign label for subword tokens (if label_all_tokens is True) - label_ids.append(labels[batch_idx][word_idx] if self.label_all_tokens else self.ignore_index) - previous_word_idx = word_idx - aligned_labels.append(label_ids) - - tokenized_inputs["labels"] = torch.tensor(aligned_labels, dtype=torch.long) - return tokenized_inputs - - def process_single(self, data, padding=None, max_length=None): - """ - Process a single example of sequence labeling data. - - Args: - data: A single data example containing tokens and labels. - padding: Padding strategy. - max_length: Maximum sequence length. - - Returns: - dict: Tokenized and aligned input data. - """ - tokens = data["tokens"] - labels = data["pos_tags"] - - tokenized_inputs = self._tokenize_and_align([tokens], [labels], padding=padding, max_length=max_length) - - data.update(tokenized_inputs) - - return data - - def process_batch(self, data, padding=None, max_length=None): - """ - Process a batch of sequence labeling examples. - - Args: - data: A batch of examples, containing tokens and labels. - padding: Padding strategy. - max_length: Maximum sequence length. - - Returns: - dict: Tokenized and aligned batch data. - """ - tokens = data["tokens"] - labels = data["pos_tags"] - - tokenized_inputs = self._tokenize_and_align(tokens, labels, padding=padding, max_length=max_length) - - data.update(tokenized_inputs) - - return data diff --git a/hezar/data/dataset_processors/speech_recognition_processor.py b/hezar/data/dataset_processors/speech_recognition_processor.py deleted file mode 100644 index 1600d4a1..00000000 --- a/hezar/data/dataset_processors/speech_recognition_processor.py +++ /dev/null @@ -1,98 +0,0 @@ -from .dataset_processor import DatasetProcessor - - -class SpeechRecognitionDatasetProcessor(DatasetProcessor): - def __init__( - self, - feature_extractor, - tokenizer, - sampling_rate=16000, - audio_array_padding="longest", - max_audio_array_length=None, - labels_padding="longest", - labels_max_length=None, - audio_field="audio", - transcript_field="transcript", - ): - super().__init__() - self.feature_extractor = feature_extractor - self.tokenizer = tokenizer - self.sampling_rate = sampling_rate - self.audio_array_padding = audio_array_padding - self.max_audio_array_length = max_audio_array_length - self.labels_padding = labels_padding - self.labels_max_length = labels_max_length - self.audio_field = audio_field - self.transcript_field = transcript_field - - def process_single(self, data): - """ - Process a single speech recognition example. - - Args: - data: A data example containing audio and its transcript. - - Returns: - dict: Processed input features and labels. 
- """ - audio_array = data[self.audio_field]["array"] - transcript = data[self.transcript_field] - - # Extract input features from audio - input_features = self.feature_extractor( - audio_array, - sampling_rate=self.sampling_rate, - padding=self.audio_array_padding, - max_length=self.max_audio_array_length, - return_tensors="torch", - )["input_features"] - - # Tokenize the transcript - labels = self.tokenizer( - transcript, - padding=self.labels_padding, - max_length=self.labels_max_length, - return_tensors="torch", - ) - - data["input_features"] = input_features - data["labels"] = labels["token_ids"] - data["attention_mask"] = labels["attention_mask"] - - return data - - def process_batch(self, data): - """ - Process a batch of speech recognition examples. - - Args: - data: A batch of data examples containing audio arrays and their corresponding transcripts. - - Returns: - dict: Batch of processed input features and labels. - """ - audio_arrays = [x["array"] for x in data[self.audio_field]] - transcripts = data[self.transcript_field] - - # Extract input features in batch - input_features = self.feature_extractor( - audio_arrays, - sampling_rate=self.sampling_rate, - padding=self.audio_array_padding, - max_length=self.max_audio_array_length, - return_tensors="torch", - )["input_features"] - - # Tokenize transcripts in batch - labels = self.tokenizer( - transcripts, - padding=self.labels_padding, - max_length=self.labels_max_length, - return_tensors="torch", - ) - - data["input_features"] = input_features - data["labels"] = labels["token_ids"] - data["attention_mask"] = labels["attention_mask"] - - return data diff --git a/hezar/data/dataset_processors/text_classification_processor.py b/hezar/data/dataset_processors/text_classification_processor.py deleted file mode 100644 index 6e88b62a..00000000 --- a/hezar/data/dataset_processors/text_classification_processor.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch - -from .dataset_processor import DatasetProcessor - - -class TextClassificationDatasetProcessor(DatasetProcessor): - def __init__(self, tokenizer, max_length=None, padding=None): - super().__init__() - self.tokenizer = tokenizer - self.padding = padding - self.max_length = max_length - - def process_single(self, data, padding=None, max_length=None): - """ - Process a single example. - """ - padding = padding or self.padding - max_length = max_length or self.max_length - - text = data["text"] - label = data["label"] - - inputs = self.tokenizer( - text, - return_tensors="torch", - truncation="longest_first", - padding=padding, - max_length=max_length, - return_attention_mask=True, - ) - data.update(inputs) - data["labels"] = torch.tensor([label], dtype=torch.long) - - return data - - def process_batch(self, data, padding=None, max_length=None): - """ - Process a batch of examples. 
- """ - padding = padding or self.padding - max_length = max_length or self.max_length - - texts = data["text"] - labels = data["label"] - - inputs = self.tokenizer( - texts, - return_tensors="torch", - truncation=True, - padding=padding, - max_length=max_length, - return_attention_mask=True, - ) - data.update(inputs) - data["labels"] = torch.tensor(labels, dtype=torch.long) - - return data diff --git a/hezar/data/dataset_processors/text_summarization_processor.py b/hezar/data/dataset_processors/text_summarization_processor.py deleted file mode 100644 index bebd17d0..00000000 --- a/hezar/data/dataset_processors/text_summarization_processor.py +++ /dev/null @@ -1,114 +0,0 @@ -from .dataset_processor import DatasetProcessor - - -class TextSummarizationDatasetProcessor(DatasetProcessor): - def __init__( - self, - tokenizer, - prefix=None, - max_length=None, - labels_max_length=None, - text_field="text", - summary_field="summary", - padding="longest", - ): - super().__init__() - self.tokenizer = tokenizer - self.prefix = prefix - self.max_length = max_length - self.labels_max_length = labels_max_length - self.text_field = text_field - self.summary_field = summary_field - self.padding = padding - - def process_single(self, data, padding=None, max_length=None, labels_max_length=None): - """ - Process a single example for text summarization. - - Args: - data: A data example containing text and summary. - padding: Padding strategy. - max_length: Max length for input text. - labels_max_length: Max length for summary labels. - - Returns: - dict: Tokenized inputs and labels for summarization task. - """ - padding = padding or self.padding - max_length = max_length or self.max_length - labels_max_length = labels_max_length or self.labels_max_length - - text = data[self.text_field] - summary = data[self.summary_field] - - # Add prefix if needed for conditional generation - if self.prefix is not None: - text = self.prefix + text - - # Tokenize inputs and labels - inputs = self.tokenizer( - text, - return_tensors="torch", - max_length=max_length, - padding=padding, - return_attention_mask=True, - truncation=True - ) - labels = self.tokenizer( - summary, - return_tensors="torch", - max_length=labels_max_length, - padding=padding, - return_attention_mask=True, - truncation=True - ) - - inputs["labels"] = labels["token_ids"].clone() - - return inputs - - def process_batch(self, data, padding=None, max_length=None, labels_max_length=None): - """ - Process a batch of examples for text summarization. - - Args: - data: A batch of examples containing texts and summaries. - padding: Padding strategy. - max_length: Max length for input texts. - labels_max_length: Max length for summary labels. - - Returns: - dict: Tokenized inputs and labels for summarization task. 
- """ - padding = padding or self.padding - max_length = max_length or self.max_length - labels_max_length = labels_max_length or self.labels_max_length - - texts = data[self.text_field] - summaries = data[self.summary_field] - - # Add prefix if needed for conditional generation - if self.prefix is not None: - texts = [self.prefix + text for text in texts] - - # Tokenize inputs and labels in batch - inputs = self.tokenizer( - texts, - return_tensors="torch", - max_length=max_length, - padding=padding, - return_attention_mask=True, - truncation=True - ) - labels = self.tokenizer( - summaries, - return_tensors="torch", - max_length=labels_max_length, - padding=padding, - return_attention_mask=True, - truncation=True - ) - - inputs["labels"] = labels["token_ids"].clone() - - return inputs From 231868042ac30c774655d9fd8621bd4c28aaf3d9 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Sat, 19 Oct 2024 20:50:38 +0330 Subject: [PATCH 12/33] :sparkles: Add `dataset_processing_example.py` --- examples/data/dataset_processing_example.py | 142 ++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 examples/data/dataset_processing_example.py diff --git a/examples/data/dataset_processing_example.py b/examples/data/dataset_processing_example.py new file mode 100644 index 00000000..fde04553 --- /dev/null +++ b/examples/data/dataset_processing_example.py @@ -0,0 +1,142 @@ +# from datasets import load_dataset +# +# from hezar.data import ImageCaptioningDatasetProcessor +# from hezar.preprocessors import Tokenizer, ImageProcessor +# +# dataset = load_dataset("hezarai/flickr30k-fa", split="train").select(indices=list(range(4000))) +# tokenizer = Tokenizer.load("hezarai/vit-roberta-fa-image-captioning-flickr30k") +# image_processor = ImageProcessor.load("hezarai/vit-roberta-fa-image-captioning-flickr30k") +# dataset_processor = ImageCaptioningDatasetProcessor(tokenizer=tokenizer, image_processor=image_processor) +# +# processed_dataset = dataset.map( +# dataset_processor, +# # batched=True, +# # batch_size=1000, +# load_from_cache_file=False, +# # num_proc=4, +# desc="Processing dataset..." +# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + +# +# from datasets import load_dataset +# +# from hezar.data import TextClassificationDatasetProcessor +# from hezar.preprocessors import Tokenizer +# +# dataset = load_dataset("hezarai/sentiment-dksf", split="train") +# tokenizer = Tokenizer.load("hezarai/roberta-base-fa") +# dataset_processor = TextClassificationDatasetProcessor(tokenizer=tokenizer, padding="longest") +# +# processed_dataset = dataset.map( +# dataset_processor, +# batched=True, +# batch_size=1000, +# load_from_cache_file=False, +# num_proc=4, +# desc="Processing dataset..." +# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + + +# from datasets import load_dataset +# +# from hezar.data import SequenceLabelingDatasetProcessor +# from hezar.preprocessors import Tokenizer +# +# dataset = load_dataset("hezarai/lscp-pos-500k", split="train") +# tokenizer = Tokenizer.load("hezarai/roberta-base-fa") +# dataset_processor = SequenceLabelingDatasetProcessor(tokenizer=tokenizer, padding="longest") +# +# processed_dataset = dataset.map( +# dataset_processor, +# batched=True, +# batch_size=1000, +# load_from_cache_file=False, +# # num_proc=4, +# desc="Processing dataset..." 
+# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + + +# from datasets import load_dataset +# +# from hezar.data import TextSummarizationDatasetProcessor +# from hezar.preprocessors import Tokenizer +# +# dataset = load_dataset("hezarai/xlsum-fa", split="train") +# tokenizer = Tokenizer.load("hezarai/t5-base-fa") +# dataset_processor = TextSummarizationDatasetProcessor(tokenizer=tokenizer, padding="longest") +# +# processed_dataset = dataset.map( +# dataset_processor, +# # batched=True, +# # batch_size=1000, +# load_from_cache_file=False, +# num_proc=10, +# desc="Processing dataset..." +# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + +# from datasets import load_dataset +# +# from hezar.data import OCRDatasetProcessor +# from hezar.preprocessors import ImageProcessor +# from hezar.configs import ModelConfig +# from hezar.utils import is_text_valid +# +# +# dataset = load_dataset("hezarai/parsynth-ocr-200k", split="train[:3000]") +# id2label = ModelConfig.load("hezarai/crnn-fa-printed-96-long", filename="model_config.yaml")["id2label"] # hack +# +# # Cleanup dataset +# max_length = 48 +# valid_indices = [] +# invalid_indices = [] +# for i, sample in enumerate(list(iter(dataset))): +# path, text = sample.values() +# if len(text) <= max_length and is_text_valid(text, id2label.values()): +# valid_indices.append(i) +# dataset = dataset.select(valid_indices) +# +# image_processor = ImageProcessor.load("hezarai/crnn-fa-printed-96-long") +# dataset_processor = OCRDatasetProcessor(image_processor=image_processor, id2label=id2label) +# processed_dataset = dataset.map( +# dataset_processor, +# # batched=True, +# # batch_size=1000, +# load_from_cache_file=False, +# # num_proc=10, +# desc="Processing dataset..." +# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + +from datasets import load_dataset + +from hezar.data import SpeechRecognitionDatasetProcessor +from hezar.preprocessors import Tokenizer, AudioFeatureExtractor + +dataset = load_dataset("hezarai/common-voice-13-fa", split="train[:100]") +tokenizer = Tokenizer.load("hezarai/whisper-small-fa") +feature_extractor = AudioFeatureExtractor.load("hezarai/whisper-small-fa") +dataset_processor = SpeechRecognitionDatasetProcessor( + tokenizer=tokenizer, + feature_extractor=feature_extractor, + transcript_field="sentence", +) + +processed_dataset = dataset.map( + dataset_processor, + batched=True, + batch_size=100, + load_from_cache_file=False, + # num_proc=10, + desc="Processing dataset..." +) +processed_dataset.set_format("torch") +print(processed_dataset[0]) \ No newline at end of file From 705f1973c1323262c98ba24beceb28fc10947ba7 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Tue, 22 Oct 2024 20:59:52 +0330 Subject: [PATCH 13/33] :sparkles: Add `return_tensors` option to dataset processors --- hezar/data/dataset_processors.py | 150 +++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 49 deletions(-) diff --git a/hezar/data/dataset_processors.py b/hezar/data/dataset_processors.py index 744fe428..7106ff86 100644 --- a/hezar/data/dataset_processors.py +++ b/hezar/data/dataset_processors.py @@ -43,31 +43,48 @@ def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs - def __call__(self, data: LazyBatch | LazyRow, **kwargs): + def __call__(self, data: LazyBatch | LazyRow, return_tensors="list", **kwargs): """ Method called when using the map function. Decides whether to call `process_single()` or `process_batch()` based on the data values. 
Args: data: A dict of feature name -> sample or batch of samples mapping. + return_tensors: The type of the returning tensors (list, torch, numpy) **kwargs: Additional keyword arguments passed through the `map` function as `kwargs`. """ if isinstance(data, LazyRow): - return self.process_single(data, **kwargs) + return self.process_single(data, return_tensors=return_tensors, **kwargs) elif isinstance(data, LazyBatch): - return self.process_batch(data, **kwargs) + return self.process_batch(data, return_tensors=return_tensors, **kwargs) else: raise ValueError(f"The input data must be either `LazyBatch` or `LazyRow`, got `{type(data)}`!") - def process_single(self, data: LazyRow, **kwargs): + def process_single(self, data: LazyRow, return_tensors=None, **kwargs): """ - Method to process a single example + Process a single data example. + + Args: + data: A data sample dict + return_tensors: The type of the returning tensors (list, torch, numpy) + **kwargs: Additional arguments + + Returns: + The updated data dict """ raise NotImplementedError - def process_batch(self, data: LazyBatch, **kwargs): + def process_batch(self, data: LazyBatch, return_tensors=None, **kwargs): """ - Method to process a batch of examples. + Process a batch of data examples. + + Args: + data: A data sample dict + return_tensors: The type of the returning tensors (list, torch, numpy) + **kwargs: Additional arguments + + Returns: + The updated data dict """ raise NotImplementedError @@ -84,7 +101,7 @@ def __init__(self, image_processor, tokenizer, max_length=None, padding=None): self.max_length = max_length self.padding = padding - def process_single(self, data, padding=None, max_length=None): + def process_single(self, data, return_tensors=None, padding=None, max_length=None): """ Process image and tokenize captions for a single data sample. @@ -92,6 +109,7 @@ def process_single(self, data, padding=None, max_length=None): data: A data example containing the image and its caption padding: Padding type e.g, max_length, longest. max_length: Max length value if padding is set to max_length or the labels must be truncated. + return_tensors: The type of the returning tensors (list, torch, numpy) Returns: A dict of pixel values tensor of the processed image and labels token ids and attention mask. @@ -102,9 +120,9 @@ def process_single(self, data, padding=None, max_length=None): path = data["image_path"] text = data["label"] - tokenized_inputs = self.tokenizer(text, padding=padding, max_length=max_length, return_tensors="torch") + tokenized_inputs = self.tokenizer(text, padding=padding, max_length=max_length, return_tensors=return_tensors) - data["pixel_values"] = self.image_processor(path, return_tensors="torch")["pixel_values"] + data["pixel_values"] = self.image_processor(path, return_tensors=return_tensors)["pixel_values"] data["labels"] = tokenized_inputs["token_ids"] data["attention_mask"] = tokenized_inputs["attention_mask"] data["decoder_input_ids"] = shift_tokens_right( @@ -115,7 +133,7 @@ def process_single(self, data, padding=None, max_length=None): return data - def process_batch(self, data, padding=None, max_length=None): + def process_batch(self, data, return_tensors=None, padding=None, max_length=None): """ Process image and tokenize captions for a batch of data samples. @@ -123,6 +141,7 @@ def process_batch(self, data, padding=None, max_length=None): data: A batch of data examples containing the images and their captions padding: Padding type e.g, max_length, longest. 
max_length: Max length value if padding is set to max_length or the labels must be truncated. + return_tensors: The type of the returning tensors (list, torch, numpy) Returns: A dict of pixel values tensor of the processed images and labels token ids and attention masks. @@ -133,9 +152,9 @@ def process_batch(self, data, padding=None, max_length=None): paths = data["image_path"] texts = data["label"] - tokenized_inputs = self.tokenizer(texts, padding=padding, max_length=max_length, return_tensors="torch") + tokenized_inputs = self.tokenizer(texts, padding=padding, max_length=max_length, return_tensors=return_tensors) - data["pixel_values"] = self.image_processor(paths, return_tensors="torch")["pixel_values"] + data["pixel_values"] = self.image_processor(paths, return_tensors=return_tensors)["pixel_values"] data["labels"] = tokenized_inputs["token_ids"] data["attention_mask"] = tokenized_inputs["attention_mask"] data["decoder_input_ids"] = shift_tokens_right( @@ -198,28 +217,30 @@ def _text_to_tensor(self, text): raise ValueError(f"Invalid `text_split_type={self.text_split_type}`") return labels - def process_single(self, data): + def process_single(self, data, return_tensors=None): """ Process a single image-to-text OCR example. Args: data: A data example containing an image path and corresponding text. + return_tensors: The type of the returning tensors (list, torch, numpy) Returns: dict: Processed inputs with pixel values and text labels. """ path = data[self.image_field] text = data[self.text_field] - pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"][0] + pixel_values = self.image_processor(path, return_tensors=return_tensors)["pixel_values"][0] labels = self._text_to_tensor(text) return {"pixel_values": pixel_values, "labels": labels} - def process_batch(self, data): + def process_batch(self, data, return_tensors=None): """ Process a batch of image-to-text OCR examples. Args: data: A batch of data examples containing image paths and corresponding texts. + return_tensors: The type of the returning tensors (list, torch, numpy) Returns: dict: Batch of processed inputs with pixel values and text labels. @@ -228,7 +249,7 @@ def process_batch(self, data): texts = data[self.text_field] # Process images in batch - pixel_values = self.image_processor(paths, return_tensors="torch")["pixel_values"] + pixel_values = self.image_processor(paths, return_tensors=return_tensors)["pixel_values"] # Process text labels in batch labels = [] @@ -251,13 +272,14 @@ def __init__(self, tokenizer, label_all_tokens=True, ignore_index=-100, max_leng self.max_length = max_length self.padding = padding - def _tokenize_and_align(self, tokens, labels, padding=None, max_length=None): + def _tokenize_and_align(self, tokens, labels, return_tensors=None, padding=None, max_length=None): """ Tokenize and align tokens and labels for sequence labeling tasks. Args: tokens: List of tokens (for single examples) or list of lists (for batches). labels: List of labels (for single examples) or list of lists (for batches). + return_tensors: The type of the returning tensors (list, torch, numpy) padding: Padding strategy for tokenization. max_length: Maximum sequence length to truncate/pad. 
@@ -275,7 +297,7 @@ def _tokenize_and_align(self, tokens, labels, padding=None, max_length=None): padding=padding, truncation=True, max_length=max_length, - return_tensors="torch" + return_tensors=return_tensors ) word_ids = tokenized_inputs["word_ids"] @@ -300,12 +322,13 @@ def _tokenize_and_align(self, tokens, labels, padding=None, max_length=None): tokenized_inputs["labels"] = torch.tensor(aligned_labels, dtype=torch.long) return tokenized_inputs - def process_single(self, data, padding=None, max_length=None): + def process_single(self, data, return_tensors=None, padding=None, max_length=None): """ Process a single example of sequence labeling data. Args: data: A single data example containing tokens and labels. + return_tensors: The type of the returning tensors (list, torch, numpy) padding: Padding strategy. max_length: Maximum sequence length. @@ -315,18 +338,25 @@ def process_single(self, data, padding=None, max_length=None): tokens = data["tokens"] labels = data["pos_tags"] - tokenized_inputs = self._tokenize_and_align([tokens], [labels], padding=padding, max_length=max_length) + tokenized_inputs = self._tokenize_and_align( + [tokens], + [labels], + return_tensors=return_tensors, + padding=padding, + max_length=max_length, + ) data.update(tokenized_inputs) return data - def process_batch(self, data, padding=None, max_length=None): + def process_batch(self, data, return_tensors=None, padding=None, max_length=None): """ Process a batch of sequence labeling examples. Args: data: A batch of examples, containing tokens and labels. + return_tensors: The type of the returning tensors (list, torch, numpy) padding: Padding strategy. max_length: Maximum sequence length. @@ -336,7 +366,13 @@ def process_batch(self, data, padding=None, max_length=None): tokens = data["tokens"] labels = data["pos_tags"] - tokenized_inputs = self._tokenize_and_align(tokens, labels, padding=padding, max_length=max_length) + tokenized_inputs = self._tokenize_and_align( + tokens, + labels, + return_tensors=return_tensors, + padding=padding, + max_length=max_length, + ) data.update(tokenized_inputs) @@ -353,9 +389,9 @@ def __init__( feature_extractor, tokenizer, sampling_rate=16000, - audio_array_padding="longest", + audio_array_padding=None, max_audio_array_length=None, - labels_padding="longest", + labels_padding=None, labels_max_length=None, audio_field="audio", transcript_field="transcript", @@ -371,12 +407,13 @@ def __init__( self.audio_field = audio_field self.transcript_field = transcript_field - def process_single(self, data): + def process_single(self, data, return_tensors=None): """ Process a single speech recognition example. Args: data: A data example containing audio and its transcript. + return_tensors: The type of the returning tensors (list, torch, numpy) Returns: dict: Processed input features and labels. @@ -390,7 +427,7 @@ def process_single(self, data): sampling_rate=self.sampling_rate, padding=self.audio_array_padding, max_length=self.max_audio_array_length, - return_tensors="torch", + return_tensors=return_tensors, )["input_features"] # Tokenize the transcript @@ -398,7 +435,7 @@ def process_single(self, data): transcript, padding=self.labels_padding, max_length=self.labels_max_length, - return_tensors="torch", + return_tensors=return_tensors, ) data["input_features"] = input_features @@ -407,12 +444,13 @@ def process_single(self, data): return data - def process_batch(self, data): + def process_batch(self, data, return_tensors=None): """ Process a batch of speech recognition examples. 
Args: data: A batch of data examples containing audio arrays and their corresponding transcripts. + return_tensors: The type of the returning tensors (list, torch, numpy) Returns: dict: Batch of processed input features and labels. @@ -426,7 +464,7 @@ def process_batch(self, data): sampling_rate=self.sampling_rate, padding=self.audio_array_padding, max_length=self.max_audio_array_length, - return_tensors="torch", + return_tensors=return_tensors, )["input_features"] # Tokenize transcripts in batch @@ -434,7 +472,7 @@ def process_batch(self, data): transcripts, padding=self.labels_padding, max_length=self.labels_max_length, - return_tensors="torch", + return_tensors=return_tensors, ) data["input_features"] = input_features @@ -455,9 +493,18 @@ def __init__(self, tokenizer, max_length=None, padding=None): self.padding = padding self.max_length = max_length - def process_single(self, data, padding=None, max_length=None): + def process_single(self, data, return_tensors=None, padding=None, max_length=None): """ - Process a single example. + Process a single example for text classification. + + Args: + data: A single data example dict + return_tensors: The type of the returning tensors (list, torch, numpy) + padding: Token ids padding type + max_length: Max input length + + Returns: + The updated data dictionary """ padding = padding or self.padding max_length = max_length or self.max_length @@ -467,20 +514,28 @@ def process_single(self, data, padding=None, max_length=None): inputs = self.tokenizer( text, - return_tensors="torch", - truncation="longest_first", padding=padding, max_length=max_length, return_attention_mask=True, + return_tensors=return_tensors, ) data.update(inputs) data["labels"] = torch.tensor([label], dtype=torch.long) return data - def process_batch(self, data, padding=None, max_length=None): + def process_batch(self, data, return_tensors=None, padding=None, max_length=None): """ - Process a batch of examples. + Process a batch of examples for text classification. + + Args: + data: A single data example dict + return_tensors: The type of the returning tensors (list, torch, numpy) + padding: Token ids padding type + max_length: Max input length + + Returns: + The updated data dictionary """ padding = padding or self.padding max_length = max_length or self.max_length @@ -490,11 +545,10 @@ def process_batch(self, data, padding=None, max_length=None): inputs = self.tokenizer( texts, - return_tensors="torch", - truncation=True, padding=padding, max_length=max_length, return_attention_mask=True, + return_tensors=return_tensors, ) data.update(inputs) data["labels"] = torch.tensor(labels, dtype=torch.long) @@ -526,12 +580,13 @@ def __init__( self.summary_field = summary_field self.padding = padding - def process_single(self, data, padding=None, max_length=None, labels_max_length=None): + def process_single(self, data, return_tensors=None, padding=None, max_length=None, labels_max_length=None): """ Process a single example for text summarization. Args: data: A data example containing text and summary. + return_tensors: The type of the returning tensors (list, torch, numpy) padding: Padding strategy. max_length: Max length for input text. labels_max_length: Max length for summary labels. 
@@ -553,31 +608,30 @@ def process_single(self, data, padding=None, max_length=None, labels_max_length= # Tokenize inputs and labels inputs = self.tokenizer( text, - return_tensors="torch", max_length=max_length, padding=padding, return_attention_mask=True, - truncation=True + return_tensors=return_tensors, ) labels = self.tokenizer( summary, - return_tensors="torch", max_length=labels_max_length, padding=padding, return_attention_mask=True, - truncation=True + return_tensors=return_tensors, ) inputs["labels"] = labels["token_ids"].clone() return inputs - def process_batch(self, data, padding=None, max_length=None, labels_max_length=None): + def process_batch(self, data, return_tensors=None, padding=None, max_length=None, labels_max_length=None): """ Process a batch of examples for text summarization. Args: data: A batch of examples containing texts and summaries. + return_tensors: The type of the returning tensors (list, torch, numpy) padding: Padding strategy. max_length: Max length for input texts. labels_max_length: Max length for summary labels. @@ -599,19 +653,17 @@ def process_batch(self, data, padding=None, max_length=None, labels_max_length=N # Tokenize inputs and labels in batch inputs = self.tokenizer( texts, - return_tensors="torch", max_length=max_length, padding=padding, return_attention_mask=True, - truncation=True + return_tensors=return_tensors, ) labels = self.tokenizer( summaries, - return_tensors="torch", max_length=labels_max_length, padding=padding, return_attention_mask=True, - truncation=True + return_tensors=return_tensors, ) inputs["labels"] = labels["token_ids"].clone() From fc5c52d0559b1a215110375fa4e7f2fe16284e42 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Tue, 22 Oct 2024 21:00:39 +0330 Subject: [PATCH 14/33] :pencil2: Temp dataset processing example --- examples/data/dataset_processing_example.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/data/dataset_processing_example.py b/examples/data/dataset_processing_example.py index fde04553..a02958c3 100644 --- a/examples/data/dataset_processing_example.py +++ b/examples/data/dataset_processing_example.py @@ -117,19 +117,28 @@ # print(processed_dataset[0]) from datasets import load_dataset +from torch.utils.data import DataLoader -from hezar.data import SpeechRecognitionDatasetProcessor +from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator from hezar.preprocessors import Tokenizer, AudioFeatureExtractor -dataset = load_dataset("hezarai/common-voice-13-fa", split="train[:100]") +dataset = load_dataset("parquet", split="train", data_files=["train-00001-of-00002.parquet"]).select(list(range(100))) +dataset = dataset.select_columns(["sentence", "audio"]) tokenizer = Tokenizer.load("hezarai/whisper-small-fa") feature_extractor = AudioFeatureExtractor.load("hezarai/whisper-small-fa") dataset_processor = SpeechRecognitionDatasetProcessor( tokenizer=tokenizer, feature_extractor=feature_extractor, transcript_field="sentence", + labels_padding=None, + audio_array_padding="max_length", +) +data_collator = SpeechRecognitionDataCollator( + feature_extractor=feature_extractor, + tokenizer=tokenizer, + labels_padding="max_length", + labels_max_length=256, ) - processed_dataset = dataset.map( dataset_processor, batched=True, @@ -138,5 +147,8 @@ # num_proc=10, desc="Processing dataset..." 
) +processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"]) processed_dataset.set_format("torch") -print(processed_dataset[0]) \ No newline at end of file +data_loader = DataLoader(processed_dataset, batch_size=16, collate_fn=data_collator) +x = next(iter(data_loader)) +print(x) From 79687c9f28a3aeade01625966561f6175ce067d1 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 25 Oct 2024 17:15:40 +0330 Subject: [PATCH 15/33] :bug: Fix speech recognition data collator bug --- hezar/data/data_collators.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index 1a685c6d..4ba8d16a 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -1,3 +1,5 @@ +from collections import defaultdict + import numpy as np import torch @@ -189,6 +191,7 @@ class ImageCaptioningDataCollator: this value as length. return_tensors (str): Specifies the dtype of the returning tensors in the batch. (`numpy`, `list`, `torch`) """ + def __init__( self, tokenizer: Tokenizer, @@ -249,17 +252,17 @@ def __init__( def __call__(self, input_batch): input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] - inputs = {} - for key in input_batch[0].keys(): - stack = [e for item in input_batch for e in item[key]] - inputs[key] = stack - + inputs = defaultdict(list) + for item in input_batch: + for key, value in item.items(): + inputs[key].append(value) + inputs = dict(inputs) inputs = self.tokenizer.pad_encoded_batch( inputs, padding=self.labels_padding, max_length=self.labels_max_length, exclude_keys=["input_features"], - return_tensors="torch" + return_tensors="torch", ) inputs = self.feature_extractor.pad( @@ -279,7 +282,7 @@ class SequenceLabelingDataCollator: Args: tokenizer (Tokenizer): A Hezar tokenizer instance. padding (str): Specifies padding strategy, either `longest` or `max_length`. - padding_side (str): Specifies from which side of each tensor to add paddings, either `left` or `right` + padding_side (str): Specifies from which side of each tensor to add paddings, either `left` or `right`. label_pad_token_id (int): Token ID for padding labels. max_length (int): If `padding` is set to `max_length` this must be specified. Forces all tensors to have this value as length. 
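Note on the fix above: the speech recognition collator previously flattened every sample's values into one long list per key; the patch regroups the batch into a dict of per-sample lists before padding. A standalone sketch of that regrouping step, for illustration only (not Hezar's actual collator code):

    from collections import defaultdict

    def regroup(input_batch):
        """Turn a list of per-sample dicts into a dict of per-key lists."""
        grouped = defaultdict(list)
        for item in input_batch:
            for key, value in item.items():
                grouped[key].append(value)  # keep one entry per sample under each key
        return dict(grouped)

    # regroup([{"labels": [1, 2]}, {"labels": [3]}]) -> {"labels": [[1, 2], [3]]}
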
From b9a2efb620feb746c7fb6d2e3d7fea510d0e6fa7 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 25 Oct 2024 17:15:59 +0330 Subject: [PATCH 16/33] :bug: Fix OCR data collator bug --- hezar/data/data_collators.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index 4ba8d16a..be1e6853 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -367,12 +367,15 @@ def __call__(self, input_batch): """ if isinstance(input_batch, (list, tuple)) and isinstance(input_batch[0], dict): input_batch = {key: [example[key] for example in input_batch] for key in input_batch[0].keys()} - input_batch["pixel_values"] = torch.stack(input_batch["pixel_values"], 0) + + if not isinstance(input_batch["pixel_values"][0], torch.Tensor): + input_batch["pixel_values"] = torch.tensor(input_batch["pixel_values"]) + elif isinstance(input_batch["pixel_values"], list) and isinstance(input_batch["pixel_values"][0], torch.Tensor): + input_batch["pixel_values"] = torch.stack(input_batch["pixel_values"]) max_length = max(map(len, input_batch["labels"])) all_labels = [] for labels in input_batch["labels"]: - labels = labels.numpy().tolist() labels += [self.pad_token_id] * (max_length - len(labels)) all_labels.append(labels) input_batch["labels"] = torch.tensor(all_labels) From 7265263a1eda20a30d98c63524a35d81b769b848 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 25 Oct 2024 17:19:22 +0330 Subject: [PATCH 17/33] :sparkles: Return non-batched output when input is not batched in `Tokenizer` and `ImageProcessor` Only applies when `return_tensors`="list" --- hezar/preprocessors/image_processor.py | 11 ++++++++--- hezar/preprocessors/tokenizers/tokenizer.py | 12 +++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/hezar/preprocessors/image_processor.py b/hezar/preprocessors/image_processor.py index 3665966a..4d30d16a 100644 --- a/hezar/preprocessors/image_processor.py +++ b/hezar/preprocessors/image_processor.py @@ -88,7 +88,7 @@ def __init__(self, config: ImageProcessorConfig, **kwargs): def __call__( self, - images: List, + images: str | List, device: str = None, mean: float = None, std: float = None, @@ -104,7 +104,7 @@ def __call__( Perform sequential image processing on a list of input images. Args: - images (List): A list of input images of types torch, numpy, pillow. + images (str | List): A list of input images (torch, numpy, pillow) OR path or list of paths to images. mean (float): Image mean value for normalization. std (float): Image std value for normalization. rescale (float): Scale factor for rescaling the image. 
@@ -126,8 +126,10 @@ def __call__( mirror = mirror or self.config.mirror gray_scale = gray_scale or self.config.gray_scale + is_single = False if not isinstance(images, list) or isinstance(images, str) or isinstance(images, np.ndarray): images = [images] + is_single = True # Load images if inputs are list of files images = [load_image(x, return_type="numpy") if isinstance(x, str) else x for x in images] @@ -162,11 +164,14 @@ def __call__( images = convert_batch_dict_dtype({"pixel_values": images}, dtype=return_tensors) - if device: + if device and return_tensors == "torch": import torch images = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in images.items()} + if is_single and return_tensors == "list": + images = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in images.items()} + return images @classmethod diff --git a/hezar/preprocessors/tokenizers/tokenizer.py b/hezar/preprocessors/tokenizers/tokenizer.py index 686e4798..e9212e1f 100644 --- a/hezar/preprocessors/tokenizers/tokenizer.py +++ b/hezar/preprocessors/tokenizers/tokenizer.py @@ -232,7 +232,8 @@ def pad_encoded_batch( padding=padding, padding_side=self.config.padding_side, pad_id=pad_id, - max_length=max_length, truncation=truncation, + max_length=max_length, + truncation=truncation, ) inputs[key] = padded_ids @@ -303,8 +304,13 @@ def __call__( " This warning will change to an error in the future!" ) + return_tensors = return_tensors or "list" + if isinstance(inputs, str): inputs = [inputs] + is_single = True + else: + is_single = False if padding is None and max_length is not None: padding = PaddingType.MAX_LENGTH @@ -358,6 +364,10 @@ def __call__( if device and return_tensors == "torch": outputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()} + # Squeeze tensor if the original input is a single string and return_tensors is `list` + if is_single and return_tensors == "list": + outputs = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in outputs.items()} + return outputs def set_truncation_and_padding( From e04da6eb75d664ed377a406fcdb812726c69ff32 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 25 Oct 2024 17:20:07 +0330 Subject: [PATCH 18/33] :pencil2: Minor renamings --- hezar/data/dataset_processors.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/hezar/data/dataset_processors.py b/hezar/data/dataset_processors.py index 7106ff86..7c90d304 100644 --- a/hezar/data/dataset_processors.py +++ b/hezar/data/dataset_processors.py @@ -8,7 +8,7 @@ >>> from datasets import load_dataset >>> from hezar.data import SpeechRecognitionDatasetProcessor ->>> data_processor = SpeechRecognitionDatasetProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) +>>> data_processor = SpeechRecognitionDatasetProcessor(feature_extractor=feature_extractor,tokenizer=tokenizer) >>> dataset = load_dataset("hezarai/common-voice-13-fa") >>> dataset = dataset.map(data_processor, batched=True, batch_size=1000) """ @@ -192,7 +192,7 @@ def __init__( self.image_field = image_field self.text_field = text_field - def _text_to_tensor(self, text): + def _text_to_ids(self, text): """ Convert text to tensor based on the configured text_split_type. 
@@ -205,14 +205,12 @@ def _text_to_tensor(self, text): """ if self.text_split_type == "tokenize": token_ids = self.tokenizer(text, padding="max_length", max_length=self.max_length)["input_ids"] - token_ids = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids] - labels = torch.tensor(token_ids) + labels = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids] elif self.text_split_type == "char_split": if self.reverse_digits: text = reverse_string_digits(text) label2id = {v: k for k, v in self.id2label.items()} labels = [label2id[char] for char in text] - labels = torch.LongTensor(labels) else: raise ValueError(f"Invalid `text_split_type={self.text_split_type}`") return labels @@ -231,7 +229,7 @@ def process_single(self, data, return_tensors=None): path = data[self.image_field] text = data[self.text_field] pixel_values = self.image_processor(path, return_tensors=return_tensors)["pixel_values"][0] - labels = self._text_to_tensor(text) + labels = self._text_to_ids(text) return {"pixel_values": pixel_values, "labels": labels} def process_batch(self, data, return_tensors=None): @@ -254,7 +252,7 @@ def process_batch(self, data, return_tensors=None): # Process text labels in batch labels = [] for text in texts: - labels.append(self._text_to_tensor(text)) + labels.append(self._text_to_ids(text)) return {"pixel_values": pixel_values, "labels": labels} @@ -393,8 +391,8 @@ def __init__( max_audio_array_length=None, labels_padding=None, labels_max_length=None, - audio_field="audio", - transcript_field="transcript", + audio_column="audio", + transcript_column="transcript", ): super().__init__() self.feature_extractor = feature_extractor @@ -404,8 +402,8 @@ def __init__( self.max_audio_array_length = max_audio_array_length self.labels_padding = labels_padding self.labels_max_length = labels_max_length - self.audio_field = audio_field - self.transcript_field = transcript_field + self.audio_column = audio_column + self.transcript_column = transcript_column def process_single(self, data, return_tensors=None): """ @@ -418,8 +416,8 @@ def process_single(self, data, return_tensors=None): Returns: dict: Processed input features and labels. """ - audio_array = data[self.audio_field]["array"] - transcript = data[self.transcript_field] + audio_array = data[self.audio_column]["array"] + transcript = data[self.transcript_column] # Extract input features from audio input_features = self.feature_extractor( @@ -455,8 +453,8 @@ def process_batch(self, data, return_tensors=None): Returns: dict: Batch of processed input features and labels. 
""" - audio_arrays = [x["array"] for x in data[self.audio_field]] - transcripts = data[self.transcript_field] + audio_arrays = [x["array"] for x in data[self.audio_column]] + transcripts = data[self.transcript_column] # Extract input features in batch input_features = self.feature_extractor( From 95b566cdbf9437dbe32e5228b64a54a0538fa088 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 25 Oct 2024 17:22:53 +0330 Subject: [PATCH 19/33] :sparkles: Return list objects in datasets when tokenizing by default --- hezar/data/datasets/ocr_dataset.py | 6 ++---- hezar/data/datasets/speech_recognition_dataset.py | 5 ++--- hezar/data/datasets/text_classification_dataset.py | 11 ++--------- hezar/data/datasets/text_summarization_dataset.py | 8 +++----- 4 files changed, 9 insertions(+), 21 deletions(-) diff --git a/hezar/data/datasets/ocr_dataset.py b/hezar/data/datasets/ocr_dataset.py index e0d05110..c11fb154 100644 --- a/hezar/data/datasets/ocr_dataset.py +++ b/hezar/data/datasets/ocr_dataset.py @@ -123,7 +123,7 @@ def _load(self, split=None): data = data.select(valid_indices) return data - def _text_to_tensor(self, text): + def _text_to_ids(self, text): """ Convert text to tensor based on the configured text_split_type. @@ -139,14 +139,12 @@ def _text_to_tensor(self, text): token_ids = self.tokenizer(text, padding="max_length", max_length=self.config.max_length)["token_ids"] # Make sure to ignore pad tokens by the loss function token_ids = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids] - labels = torch.tensor(token_ids) # If split text is not tokenizer-based elif self.config.text_split_type == TextSplitType.CHAR_SPLIT: if self.config.reverse_digits: text = reverse_string_digits(text) label2id = {v: k for k, v in self.config.id2label.items()} labels = [label2id[x] for x in text] - labels = torch.LongTensor(labels) else: raise ValueError(f"Invalid `text_split_type={self.config.text_split_type}`") @@ -165,7 +163,7 @@ def __getitem__(self, index): """ path, text = self.data[index].values() pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"][0] - labels = self._text_to_tensor(text) + labels = self._text_to_ids(text) inputs = { "pixel_values": pixel_values, "labels": labels, diff --git a/hezar/data/datasets/speech_recognition_dataset.py b/hezar/data/datasets/speech_recognition_dataset.py index d8a5c3f2..1d85c3e1 100644 --- a/hezar/data/datasets/speech_recognition_dataset.py +++ b/hezar/data/datasets/speech_recognition_dataset.py @@ -62,14 +62,13 @@ def __getitem__(self, index): input_features = self.feature_extractor( audio_array, sampling_rate=self.config.sampling_rate, - return_tensors="torch" - )["input_features"] + return_tensors="numpy" + )["input_features"][0] labels = self.tokenizer( transcript, padding=self.config.labels_padding, max_length=self.config.labels_max_length, - return_tensors="torch", ) return { diff --git a/hezar/data/datasets/text_classification_dataset.py b/hezar/data/datasets/text_classification_dataset.py index b27ad858..6e8a5c9d 100644 --- a/hezar/data/datasets/text_classification_dataset.py +++ b/hezar/data/datasets/text_classification_dataset.py @@ -94,14 +94,7 @@ def __getitem__(self, index): """ text = self.data[index][self.config.text_field] label = self.data[index][self.config.label_field] - inputs = self.tokenizer( - text, - return_tensors="torch", - truncation=True, - padding="longest", - return_attention_mask=True, - ) - label_idx = torch.tensor([label], dtype=torch.long) # noqa - inputs["labels"] = 
label_idx + inputs = self.tokenizer(text, return_attention_mask=True) + inputs["labels"] = label return inputs diff --git a/hezar/data/datasets/text_summarization_dataset.py b/hezar/data/datasets/text_summarization_dataset.py index aaa7be91..2dfd4f0f 100644 --- a/hezar/data/datasets/text_summarization_dataset.py +++ b/hezar/data/datasets/text_summarization_dataset.py @@ -59,7 +59,7 @@ def __init__(self, config: TextSummarizationDatasetConfig, split=None, preproces tokenizer=self.tokenizer, max_length=self.config.max_length, labels_max_length=self.config.labels_max_length, - padding="max_length" if self.config.max_length else "longest", + padding="max_length" if self.config.max_length else None, ) def _load(self, split): @@ -94,16 +94,14 @@ def __getitem__(self, index): inputs = self.tokenizer( text, - return_tensors="torch", max_length=self.config.max_length, - padding="max_length" if self.config.max_length else "longest", + padding="max_length" if self.config.max_length else None, return_attention_mask=True, ) labels = self.tokenizer( summary, - return_tensors="torch", max_length=self.config.max_length, - padding="max_length" if self.config.labels_max_length else "longest", + padding="max_length" if self.config.labels_max_length else None, return_attention_mask=True, ) From 57870ed24140c2a108611de6c5dbaee9ce3b3d96 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 31 Oct 2024 10:46:46 +0330 Subject: [PATCH 20/33] :bug: Handle batched inputs better in Tokenizer call --- hezar/preprocessors/tokenizers/tokenizer.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/hezar/preprocessors/tokenizers/tokenizer.py b/hezar/preprocessors/tokenizers/tokenizer.py index e9212e1f..468d99ac 100644 --- a/hezar/preprocessors/tokenizers/tokenizer.py +++ b/hezar/preprocessors/tokenizers/tokenizer.py @@ -306,11 +306,12 @@ def __call__( return_tensors = return_tensors or "list" - if isinstance(inputs, str): + # Convert to batch if input is a single string or a list of words (is split into words for sequence labeling) + if isinstance(inputs, str) or (is_split_into_words and not isinstance(inputs[0], list)): inputs = [inputs] - is_single = True + is_batch = False else: - is_single = False + is_batch = True if padding is None and max_length is not None: padding = PaddingType.MAX_LENGTH @@ -360,14 +361,18 @@ def __call__( overflow_to_sample_mapping += [i] * len(encodings_["input_ids"]) sanitized_outputs["overflow_to_sample_mapping"] = overflow_to_sample_mapping + # Squeeze tensor if the original input is a single string and return_tensors is `list` + if (return_tensors == "list" or return_tensors is None) and not is_batch: + sanitized_outputs = { + key: value[0] if len(value) > 0 and isinstance(value[0], list) else value + for key, value in sanitized_outputs.items() + } + outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors, skip_keys=self.uncastable_keys) + if device and return_tensors == "torch": outputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()} - # Squeeze tensor if the original input is a single string and return_tensors is `list` - if is_single and return_tensors == "list": - outputs = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in outputs.items()} - return outputs def set_truncation_and_padding( From f37a80e821bb8e779a83a555aea13f2a073d9e9d Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 31 Oct 2024 11:38:05 +0330 Subject: [PATCH 21/33] :fire: Remove unnecessary unbatching in data 
collators --- hezar/data/data_collators.py | 41 ++---------------------------------- 1 file changed, 2 insertions(+), 39 deletions(-) diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index be1e6853..573da3fb 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -54,14 +54,6 @@ def __init__( "attention_mask": 0, } - if padding == "longest" and max_length is not None: - logger.warning( - "You passed `max_length` while also setting `padding` to `longest` which are " - "incompatible! Instead leave `max_length` as None or set `padding` to `max_length`! " - "Ignoring `max_length`" - ) - self.max_length = None - def __call__(self, encoded_batch): """ Add padding to every item in the batch @@ -73,12 +65,7 @@ def __call__(self, encoded_batch): Dict: The same batch dictionary but padded """ encoded_batch = [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - permuted_batch = {} - for key in encoded_batch[0].keys(): - stack = [e for item in encoded_batch for e in item[key]] - permuted_batch[key] = stack - encoded_batch = permuted_batch.copy() if "label" in encoded_batch: encoded_batch["labels"] = encoded_batch["label"] del encoded_batch["label"] @@ -137,14 +124,6 @@ def __init__( self.labels_max_length = labels_max_length self.return_tensors = return_tensors - if padding == "longest" and max_length is not None: - logger.warning( - "You passed `max_length` while also setting `padding` to `longest` which are " - "incompatible! Instead leave `max_length` as None or set `padding` to `max_length`! " - "Ignoring `max_length`" - ) - self.max_length = None - def __call__(self, encoded_batch): """ Add padding to every item in the batch @@ -156,13 +135,9 @@ def __call__(self, encoded_batch): Dict: The same batch dictionary but padded """ encoded_batch = [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - permuted_batch = {} - for key in encoded_batch[0].keys(): - stack = [e for item in encoded_batch for e in item[key]] - permuted_batch[key] = stack padded_batch = self.tokenizer.pad_encoded_batch( - permuted_batch, + encoded_batch, padding=self.padding, max_length=self.max_length, exclude_keys=["labels"], @@ -206,23 +181,11 @@ def __init__( self.max_length = max_length self.return_tensors = return_tensors - if padding == "longest" and max_length is not None: - logger.warning( - "You passed `max_length` while also setting `padding` to `longest` which are " - "incompatible! Instead leave `max_length` as None or set `padding` to `max_length`! 
" - "Ignoring `max_length`" - ) - self.max_length = None - def __call__(self, encoded_batch): encoded_batch = [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - permuted_batch = {} - for key in encoded_batch[0].keys(): - stack = [e for item in encoded_batch for e in item[key]] - permuted_batch[key] = stack padded_batch = self.tokenizer.pad_encoded_batch( - permuted_batch, + encoded_batch, padding=self.padding, max_length=self.max_length, exclude_keys=["pixel_values"], From 2c32b675b0a2d3bb57f4019456f976c8f304f9ac Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 31 Oct 2024 11:38:28 +0330 Subject: [PATCH 22/33] :pencil2: Minor changes in dataset processors --- hezar/data/dataset_processors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hezar/data/dataset_processors.py b/hezar/data/dataset_processors.py index 7c90d304..5af1738f 100644 --- a/hezar/data/dataset_processors.py +++ b/hezar/data/dataset_processors.py @@ -567,7 +567,7 @@ def __init__( labels_max_length=None, text_field="text", summary_field="summary", - padding="longest", + padding=None, ): super().__init__() self.tokenizer = tokenizer @@ -619,7 +619,7 @@ def process_single(self, data, return_tensors=None, padding=None, max_length=Non return_tensors=return_tensors, ) - inputs["labels"] = labels["token_ids"].clone() + inputs["labels"] = labels["token_ids"] return inputs @@ -664,6 +664,6 @@ def process_batch(self, data, return_tensors=None, padding=None, max_length=None return_tensors=return_tensors, ) - inputs["labels"] = labels["token_ids"].clone() + inputs["labels"] = labels["token_ids"] return inputs From 4cdf79db3bd2c652da410e60c4f1fe4b42b35065 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 31 Oct 2024 11:39:11 +0330 Subject: [PATCH 23/33] :fire: Deprecate setting `max_length`, `padding`, `truncation` in the Tokenizer config --- hezar/preprocessors/tokenizers/tokenizer.py | 29 ++++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/hezar/preprocessors/tokenizers/tokenizer.py b/hezar/preprocessors/tokenizers/tokenizer.py index 468d99ac..c5c1e36b 100644 --- a/hezar/preprocessors/tokenizers/tokenizer.py +++ b/hezar/preprocessors/tokenizers/tokenizer.py @@ -61,13 +61,13 @@ class TokenizerConfig(PreprocessorConfig): """ name = "tokenizer" - max_length: int = None - truncation: str = None + max_length: int = "deprecated" + truncation: str = "deprecated" truncation_side: str = None - padding: str = None + padding: str = "deprecated" padding_side: str = None stride: int = None - pad_to_multiple_of: int = None + pad_to_multiple_of: int = "deprecated" pad_token_type_id: int = 0 bos_token: str = None eos_token: str = None @@ -78,6 +78,21 @@ class TokenizerConfig(PreprocessorConfig): mask_token: str = None additional_special_tokens: List[str] = None + def __post_init__(self): + super().__post_init__() + if self.max_length != "deprecated": + logger.warning( + "Setting `max_length` in the tokenizer config is deprecated and will be removed in the future!" + ) + if self.padding != "deprecated": + logger.warning( + "Setting `padding` in the tokenizer config is deprecated and will be removed in the future!" + ) + if self.truncation != "deprecated": + logger.warning( + "Setting `truncation` in the tokenizer config is deprecated and will be removed in the future!" 
+ ) + class Tokenizer(Preprocessor): """ @@ -313,12 +328,6 @@ def __call__( else: is_batch = True - if padding is None and max_length is not None: - padding = PaddingType.MAX_LENGTH - truncation = truncation or self.config.truncation - max_length = max_length or self.config.max_length - pad_to_multiple_of = pad_to_multiple_of or self.config.pad_to_multiple_of - self.set_truncation_and_padding( padding=padding, truncation=truncation, From 6e2e3fa01a5b7c97c44ed0b6b0a602b3e90fcdcb Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 31 Oct 2024 15:58:20 +0330 Subject: [PATCH 24/33] :adhesive_bandage: Make batch conversions cleaner in data collators --- hezar/data/data_collators.py | 101 +++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 47 deletions(-) diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index 573da3fb..c36d0648 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -19,6 +19,24 @@ logger = Logger(__name__) +def _convert_to_batch_dict(dicts_list: list[dict]): + """ + Convert a list of dicts to a dict of batched values. + + Args: + dicts_list: A list of dictionaries containing the same set of keys + + Returns: + A dictionary of the batches + """ + batch_dict = defaultdict(list) + for item in dicts_list: + for key, value in item.items(): + batch_dict[key].append(value) + batch_dict = dict(batch_dict) + return batch_dict + + class TextPaddingDataCollator: """ A data collator that pads a batch of tokenized inputs. @@ -54,26 +72,22 @@ def __init__( "attention_mask": 0, } - def __call__(self, encoded_batch): + def __call__(self, input_batch): """ Add padding to every item in the batch Args: - encoded_batch: A batch dictionary + input_batch: A batch dictionary Returns: Dict: The same batch dictionary but padded """ - encoded_batch = [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - - if "label" in encoded_batch: - encoded_batch["labels"] = encoded_batch["label"] - del encoded_batch["label"] - - labels = encoded_batch.pop("labels") - input_length = self.max_length or max(len(x) for x in encoded_batch["token_ids"]) + input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] + input_batch = _convert_to_batch_dict(input_batch) + labels = input_batch.pop("labels") + input_length = self.max_length or max(len(x) for x in input_batch["token_ids"]) - for field, batch in encoded_batch.items(): + for field, batch in input_batch.items(): padded_batch = [] for x in batch: if isinstance(x, torch.Tensor): @@ -84,13 +98,13 @@ def __call__(self, encoded_batch): paddings = [self.field_to_pad_id_mapping[field]] * difference padded_x = x + paddings if self.padding_side == "right" else paddings + x padded_batch.append(padded_x) - encoded_batch[field] = padded_batch + input_batch[field] = padded_batch - encoded_batch["labels"] = labels + input_batch["labels"] = labels - encoded_batch = convert_batch_dict_dtype(encoded_batch, dtype=self.return_tensors) + input_batch = convert_batch_dict_dtype(input_batch, dtype=self.return_tensors) - return encoded_batch + return input_batch class TextGenerationDataCollator: @@ -124,20 +138,20 @@ def __init__( self.labels_max_length = labels_max_length self.return_tensors = return_tensors - def __call__(self, encoded_batch): + def __call__(self, input_batch): """ Add padding to every item in the batch Args: - encoded_batch (List[Dict]): A batch dictionary + input_batch (List[Dict]): A batch dictionary Returns: Dict: The same batch dictionary but padded """ - encoded_batch 
= [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - + input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] + input_batch = _convert_to_batch_dict(input_batch) padded_batch = self.tokenizer.pad_encoded_batch( - encoded_batch, + input_batch, padding=self.padding, max_length=self.max_length, exclude_keys=["labels"], @@ -181,11 +195,11 @@ def __init__( self.max_length = max_length self.return_tensors = return_tensors - def __call__(self, encoded_batch): - encoded_batch = [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - + def __call__(self, input_batch): + input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] + input_batch = _convert_to_batch_dict(input_batch) padded_batch = self.tokenizer.pad_encoded_batch( - encoded_batch, + input_batch, padding=self.padding, max_length=self.max_length, exclude_keys=["pixel_values"], @@ -215,13 +229,9 @@ def __init__( def __call__(self, input_batch): input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] - inputs = defaultdict(list) - for item in input_batch: - for key, value in item.items(): - inputs[key].append(value) - inputs = dict(inputs) + input_batch = _convert_to_batch_dict(input_batch) inputs = self.tokenizer.pad_encoded_batch( - inputs, + input_batch, padding=self.labels_padding, max_length=self.labels_max_length, exclude_keys=["input_features"], @@ -268,43 +278,41 @@ def __init__( self.max_length = max_length self.return_tensors = return_tensors - def __call__(self, encoded_batch): + def __call__(self, input_batch): """ Add padding to every item in the batch Args: - encoded_batch (List[Dict]): A batch dictionary + input_batch (List[Dict]): A batch dictionary Returns: Dict: The same batch dictionary but padded """ - label_name = "label" if "label" in encoded_batch[0].keys() else "labels" - labels = [feature[label_name] for feature in encoded_batch] if label_name in encoded_batch[0].keys() else None + input_batch = _convert_to_batch_dict(input_batch) + labels = input_batch["labels"] self.tokenizer.config.padding_side = self.padding_side - batch = self.tokenizer.pad_encoded_batch( - encoded_batch, + input_batch = self.tokenizer.pad_encoded_batch( + input_batch, padding=self.padding, # noqa max_length=self.max_length, - # Conversion to tensors will fail if we have labels as they are not of the same length yet. - return_tensors="torch" if labels is None else None, ) if labels is None: - return batch + return input_batch - batch.pop("word_ids", None) - sequence_length = torch.tensor(batch["token_ids"]).shape[1] + input_batch.pop("word_ids", None) + sequence_length = torch.tensor(input_batch["token_ids"]).shape[1] if self.padding_side == "right": - batch[label_name] = [ + input_batch["labels"] = [ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels ] else: - batch[label_name] = [ + input_batch["labels"] = [ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels ] - batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} - return batch + input_batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in input_batch.items()} + return input_batch class CharLevelOCRDataCollator: @@ -328,8 +336,7 @@ def __call__(self, input_batch): Returns: Dict: Padded input batch. 
""" - if isinstance(input_batch, (list, tuple)) and isinstance(input_batch[0], dict): - input_batch = {key: [example[key] for example in input_batch] for key in input_batch[0].keys()} + input_batch = _convert_to_batch_dict(input_batch) if not isinstance(input_batch["pixel_values"][0], torch.Tensor): input_batch["pixel_values"] = torch.tensor(input_batch["pixel_values"]) From a1bd4c24efadc2fb3ad91859eadd31ae1e754ae7 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 31 Oct 2024 16:00:16 +0330 Subject: [PATCH 25/33] :adhesive_bandage: Fix truncation issue when `max_length` is None in `Tokenizer` --- hezar/preprocessors/tokenizers/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hezar/preprocessors/tokenizers/tokenizer.py b/hezar/preprocessors/tokenizers/tokenizer.py index c5c1e36b..15be87ff 100644 --- a/hezar/preprocessors/tokenizers/tokenizer.py +++ b/hezar/preprocessors/tokenizers/tokenizer.py @@ -395,7 +395,7 @@ def set_truncation_and_padding( pad_to_multiple_of: int = None, ): # Set truncation and padding on the backend tokenizer - if truncation == "no_truncation" or truncation is None: + if truncation == "no_truncation" or truncation is None or max_length is None: if self.truncation is not None: self.no_truncation() else: From 6288c07b9c41be3d39dcd42659e3a7b7f76a774f Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 31 Oct 2024 16:00:51 +0330 Subject: [PATCH 26/33] :adhesive_bandage: Fix some issues in dataset processors --- hezar/data/dataset_processors.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/hezar/data/dataset_processors.py b/hezar/data/dataset_processors.py index 5af1738f..57b49c0f 100644 --- a/hezar/data/dataset_processors.py +++ b/hezar/data/dataset_processors.py @@ -12,10 +12,11 @@ >>> dataset = load_dataset("hezarai/common-voice-13-fa") >>> dataset = dataset.map(data_processor, batched=True, batch_size=1000) """ + import torch from ..constants import Backends -from ..utils import is_backend_available, reverse_string_digits, shift_tokens_right, verify_dependencies +from ..utils import is_backend_available, reverse_string_digits, verify_dependencies if is_backend_available(Backends.DATASETS): @@ -101,6 +102,25 @@ def __init__(self, image_processor, tokenizer, max_length=None, padding=None): self.max_length = max_length self.padding = padding + @staticmethod + def _shift_tokens_right(input_ids: list[list[int]], pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + # Initialize shifted_input_ids as a list of lists with the same shape as input_ids + shifted_input_ids = [[0] * len(row) for row in input_ids] + + for i, row in enumerate(input_ids): + # Shift each row one token to the right + shifted_input_ids[i][1:] = row[:-1] + # Set the first token to decoder_start_token_id + shifted_input_ids[i][0] = decoder_start_token_id + + # Replace any -100 values with pad_token_id + shifted_input_ids[i] = [pad_token_id if token == -100 else token for token in shifted_input_ids[i]] + + return shifted_input_ids + def process_single(self, data, return_tensors=None, padding=None, max_length=None): """ Process image and tokenize captions for a single data sample. 
@@ -125,7 +145,7 @@ def process_single(self, data, return_tensors=None, padding=None, max_length=Non data["pixel_values"] = self.image_processor(path, return_tensors=return_tensors)["pixel_values"] data["labels"] = tokenized_inputs["token_ids"] data["attention_mask"] = tokenized_inputs["attention_mask"] - data["decoder_input_ids"] = shift_tokens_right( + data["decoder_input_ids"] = self._shift_tokens_right( data["labels"], pad_token_id=self.tokenizer.pad_token_id, decoder_start_token_id=self.tokenizer.bos_token_id, @@ -157,7 +177,7 @@ def process_batch(self, data, return_tensors=None, padding=None, max_length=None data["pixel_values"] = self.image_processor(paths, return_tensors=return_tensors)["pixel_values"] data["labels"] = tokenized_inputs["token_ids"] data["attention_mask"] = tokenized_inputs["attention_mask"] - data["decoder_input_ids"] = shift_tokens_right( + data["decoder_input_ids"] = self._shift_tokens_right( data["labels"], pad_token_id=self.tokenizer.pad_token_id, decoder_start_token_id=self.tokenizer.bos_token_id, @@ -317,7 +337,7 @@ def _tokenize_and_align(self, tokens, labels, return_tensors=None, padding=None, previous_word_idx = word_idx aligned_labels.append(label_ids) - tokenized_inputs["labels"] = torch.tensor(aligned_labels, dtype=torch.long) + tokenized_inputs["labels"] = aligned_labels return tokenized_inputs def process_single(self, data, return_tensors=None, padding=None, max_length=None): From 1e54f810f4faf0b099fa600a60e5e328b05ecbd1 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 31 Oct 2024 20:22:11 +0330 Subject: [PATCH 27/33] :adhesive_bandage: Return `decoder_attention_mask` instead of `attention_mask` in image captioning dataset processor --- hezar/data/dataset_processors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hezar/data/dataset_processors.py b/hezar/data/dataset_processors.py index 57b49c0f..3d32f084 100644 --- a/hezar/data/dataset_processors.py +++ b/hezar/data/dataset_processors.py @@ -176,7 +176,7 @@ def process_batch(self, data, return_tensors=None, padding=None, max_length=None data["pixel_values"] = self.image_processor(paths, return_tensors=return_tensors)["pixel_values"] data["labels"] = tokenized_inputs["token_ids"] - data["attention_mask"] = tokenized_inputs["attention_mask"] + data["decoder_attention_mask"] = tokenized_inputs["attention_mask"] data["decoder_input_ids"] = self._shift_tokens_right( data["labels"], pad_token_id=self.tokenizer.pad_token_id, From d1515a68d4742a97609c3f7d5bc358cf9640a3ae Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 7 Nov 2024 20:20:59 +0330 Subject: [PATCH 28/33] :pencil2: Fix issues in data collators --- hezar/data/data_collators.py | 37 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index c36d0648..3e954968 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -118,7 +118,6 @@ class TextGenerationDataCollator: max_length (int): If `padding` is set to `max_length` this must be specified. Forces all tensors to have this value as length. labels_max_length (int): Maximum target length for text generation. - return_tensors (str): Specifies the dtype of the returning tensors in the batch. 
(`numpy`, `list`, `torch`) """ @@ -129,14 +128,12 @@ def __init__( padding_side: str = "right", max_length: int = None, labels_max_length: int = None, - return_tensors: str = "torch", ): self.tokenizer = tokenizer self.padding = padding self.padding_side = padding_side self.max_length = max_length self.labels_max_length = labels_max_length - self.return_tensors = return_tensors def __call__(self, input_batch): """ @@ -155,14 +152,14 @@ def __call__(self, input_batch): padding=self.padding, max_length=self.max_length, exclude_keys=["labels"], - return_tensors=self.return_tensors, + return_tensors="torch", ) padded_batch = self.tokenizer.pad_encoded_batch( padded_batch, padding=self.padding, max_length=self.labels_max_length, include_keys=["labels"], - return_tensors=self.return_tensors, + return_tensors="torch", ) return padded_batch @@ -178,7 +175,6 @@ class ImageCaptioningDataCollator: padding_side (str): Specifies from which side of each tensor to add paddings, either `left` or `right` max_length (int): If `padding` is set to `max_length` this must be specified. Forces all tensors to have this value as length. - return_tensors (str): Specifies the dtype of the returning tensors in the batch. (`numpy`, `list`, `torch`) """ def __init__( @@ -187,27 +183,30 @@ def __init__( padding: str = "longest", padding_side: str = "right", max_length: int = None, - return_tensors: str = "torch", ): self.tokenizer = tokenizer self.padding = padding self.padding_side = padding_side self.max_length = max_length - self.return_tensors = return_tensors def __call__(self, input_batch): - input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] input_batch = _convert_to_batch_dict(input_batch) - padded_batch = self.tokenizer.pad_encoded_batch( + input_batch = self.tokenizer.pad_encoded_batch( input_batch, padding=self.padding, max_length=self.max_length, exclude_keys=["pixel_values"], - return_tensors=self.return_tensors, + return_tensors="torch", ) - padded_batch = convert_batch_dict_dtype(padded_batch, dtype="torch") + if isinstance(input_batch["pixel_values"], list): + if isinstance(input_batch["pixel_values"][0], list): + input_batch["pixel_values"] = torch.tensor(input_batch["pixel_values"]) + elif isinstance(input_batch["pixel_values"][0], torch.Tensor): + input_batch["pixel_values"] = torch.stack(input_batch["pixel_values"]) + elif isinstance(input_batch["pixel_values"][0], np.ndarray): + input_batch["pixel_values"] = torch.stack([torch.from_numpy(x) for x in input_batch["pixel_values"]]) - return padded_batch + return input_batch class SpeechRecognitionDataCollator: @@ -228,7 +227,6 @@ def __init__( self.labels_max_length = labels_max_length def __call__(self, input_batch): - input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] input_batch = _convert_to_batch_dict(input_batch) inputs = self.tokenizer.pad_encoded_batch( input_batch, @@ -259,7 +257,6 @@ class SequenceLabelingDataCollator: label_pad_token_id (int): Token ID for padding labels. max_length (int): If `padding` is set to `max_length` this must be specified. Forces all tensors to have this value as length. - return_tensors (str): Specifies the dtype of the returning tensors in the batch. 
(`numpy`, `list`, `torch`) """ def __init__( @@ -269,14 +266,12 @@ def __init__( padding_side: str = "right", label_pad_token_id: int = -100, max_length: int = None, - return_tensors: str = "torch", ): self.tokenizer = tokenizer self.padding = padding self.padding_side = padding_side self.label_pad_token_id = label_pad_token_id self.max_length = max_length - self.return_tensors = return_tensors def __call__(self, input_batch): """ @@ -295,13 +290,14 @@ def __call__(self, input_batch): input_batch, padding=self.padding, # noqa max_length=self.max_length, + return_tensors="torch", ) if labels is None: return input_batch input_batch.pop("word_ids", None) - sequence_length = torch.tensor(input_batch["token_ids"]).shape[1] + sequence_length = input_batch["token_ids"].shape[1] if self.padding_side == "right": input_batch["labels"] = [ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels @@ -311,7 +307,10 @@ def __call__(self, input_batch): [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels ] - input_batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in input_batch.items()} + input_batch = { + k: torch.tensor(v) if not isinstance(v, torch.Tensor) else v for k, v in input_batch.items() + } + return input_batch From f6ae4266346ce3461cfa0f182f2330a7b17c6fef Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 7 Nov 2024 20:21:34 +0330 Subject: [PATCH 29/33] :sparkles: Make `shift_tokens_right` compatible with all tensor types --- hezar/utils/data_utils.py | 55 ++++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/hezar/utils/data_utils.py b/hezar/utils/data_utils.py index 2c714435..bf04001a 100644 --- a/hezar/utils/data_utils.py +++ b/hezar/utils/data_utils.py @@ -1,17 +1,16 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional +import numpy as np +import torch from omegaconf import DictConfig from ..constants import PaddingType from .logging import Logger -if TYPE_CHECKING: - import torch - __all__ = [ "convert_batch_dict_dtype", "resolve_inputs_length_for_padding", @@ -166,18 +165,50 @@ def pad_batch_items( return padded_inputs -def shift_tokens_right(input_ids: "torch.Tensor", pad_token_id: int, decoder_start_token_id: int): +def shift_tokens_right( + token_ids: list[list[int]] | "torch.Tensor" | "np.ndarray", + pad_token_id: int, + decoder_start_token_id: int +): """ Shift input ids one token to the right. 
""" - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() - shifted_input_ids[:, 0] = decoder_start_token_id + # Check if input is a list of lists + if isinstance(token_ids, list): + # Initialize shifted_input_ids with the same shape as input_ids + shifted_input_ids = [[0] * len(row) for row in token_ids] + + for i, row in enumerate(token_ids): + # Shift each row one token to the right + shifted_input_ids[i][1:] = row[:-1] + # Set the first token to decoder_start_token_id + shifted_input_ids[i][0] = decoder_start_token_id + # Replace any -100 values with pad_token_id + shifted_input_ids[i] = [pad_token_id if token == -100 else token for token in shifted_input_ids[i]] + return shifted_input_ids + + # Check if input is a NumPy array + elif isinstance(token_ids, np.ndarray): + # Initialize shifted_input_ids with zeros and the same shape as input_ids + shifted_input_ids = np.zeros_like(token_ids) + shifted_input_ids[:, 1:] = token_ids[:, :-1] + shifted_input_ids[:, 0] = decoder_start_token_id + # Replace any -100 values with pad_token_id + shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + # Check if input is a PyTorch tensor + elif isinstance(token_ids, torch.Tensor): + # Initialize shifted_input_ids with zeros and the same shape as input_ids + shifted_input_ids = token_ids.new_zeros(token_ids.shape) + shifted_input_ids[:, 1:] = token_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + # Replace any -100 values with pad_token_id + shifted_input_ids = shifted_input_ids.masked_fill(shifted_input_ids == -100, pad_token_id) + return shifted_input_ids - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - return shifted_input_ids + else: + raise TypeError("Unsupported input type. Expected list, numpy array, or torch tensor.") def torch2numpy(*args): From e14d10636858bce68e464da4ee410aec347de294 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 7 Nov 2024 20:22:24 +0330 Subject: [PATCH 30/33] :pencil2: Return unbatched pixel values in `ImageCaptioningDataset` --- hezar/data/datasets/image_captioning_dataset.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/hezar/data/datasets/image_captioning_dataset.py b/hezar/data/datasets/image_captioning_dataset.py index 51e53fd4..29505828 100644 --- a/hezar/data/datasets/image_captioning_dataset.py +++ b/hezar/data/datasets/image_captioning_dataset.py @@ -77,19 +77,17 @@ def __getitem__(self, index): dict: The input data. 
""" path, text = self.data[index].values() - pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"] + pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"][0] tokenized_inputs = self.tokenizer(text, padding="max_length", max_length=self.config.max_length) - labels = torch.tensor([tokenized_inputs["token_ids"]]) - attention_mask = torch.tensor([tokenized_inputs["attention_mask"]]) decoder_input_ids = shift_tokens_right( - labels, + [tokenized_inputs["token_ids"]], pad_token_id=self.tokenizer.pad_token_id, decoder_start_token_id=self.tokenizer.bos_token_id, - ) + )[0] inputs = { "pixel_values": pixel_values, - "labels": labels, + "labels": tokenized_inputs["token_ids"], "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": attention_mask, + "decoder_attention_mask": tokenized_inputs["attention_mask"], } return inputs From 2b147ff6ef95e371ee7ccee9aa7eb8435c58cb08 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 8 Nov 2024 12:54:56 +0330 Subject: [PATCH 31/33] :memo: Add dataset processor guides in the docs --- docs/guide/dataset_processors.md | 186 +++++++++++++++++++++++++++++++ docs/guide/index.md | 1 + docs/tutorial/datasets.md | 91 ++++++++++++++- 3 files changed, 277 insertions(+), 1 deletion(-) create mode 100644 docs/guide/dataset_processors.md diff --git a/docs/guide/dataset_processors.md b/docs/guide/dataset_processors.md new file mode 100644 index 00000000..6ba99c25 --- /dev/null +++ b/docs/guide/dataset_processors.md @@ -0,0 +1,186 @@ +# Dataset Processors +Initially, Hezar's Trainer worked only with PyTorch Datasets (derived from `torch.utils.data.Dataset`) like all Hezar datasets classes +at `hezar.data.datasets`. Moving on, we also added support for any iterable as the dataset in Hezar's Trainer. + +One really important type of datasets is 🤗 Datasets. The Trainer almost supported these type of datasets since day one, +but implementing the data pipelines must have been handled by the user. That's why Hezar (`v0.42.0>=`) added a new category of classes +called Dataset Processors. These classes are used as dataset map callables which has the following benefits: +- The same processing pipeline in the corresponding `hezar.data.Dataset` subclass is implemented as a map function. +For example, `SpeechRecognitionDatasetProcessor` corresponds to `SpeechRecognitionDataset`. +- Features like cacheing, multiprocessing, batch processing, etc. are now available since objects are of type `datasets.Dataset`. +- Other dataset processing pipelines from other codes feel like plug-and-play to work with Hezar's `Trainer`. + +Now lets see an example demonstrating both cases: + +**Classic 🤗Datasets** + +Here we need to implement a map function that processes our samples. 🤗Datasets `map` function works on callables that +operate on either single or batched inputs. 
Below is an implementation for batched processing: +```python +from datasets import load_dataset, Audio +from hezar.preprocessors import Preprocessor + + +preprocesssor = Preprocessor.load("hezarai/whisper-small-fa") +feature_extractor = preprocesssor.audio_feature_extractor +tokenizer = preprocesssor.tokenizer + +def batch_process_fn(data): + # Extract audio arrays and transcripts + audio_arrays = data["audio"] # Assuming audio arrays are stored under the "audio" key + transcripts = data["transcript"] # Assuming transcripts are stored under the "transcript" key + + # Extract input features in batch + input_features = feature_extractor( + audio_arrays, + sampling_rate=16000, + return_tensors="np", # Return as numpy for compatibility with map + )["input_features"] + + # Tokenize transcripts in batch + labels = tokenizer( + transcripts, + padding="max_length", + max_length=448, + return_tensors="np", + ) + + # Add processed data to the dictionary + data["input_features"] = input_features + data["labels"] = labels["input_ids"] + data["attention_mask"] = labels["attention_mask"] + + return data + +dataset = load_dataset("hezarai/common-voice-13-fa", split="train") +dataset = dataset.cast_column("audio", Audio(sampling_rate=16000)) +dataset = dataset.select_columns(["sentence", "audio"]) +# Apply the function to the dataset using map +processed_dataset = dataset.map(batch_process_fn, batched=True) +processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"]) +print(processed_dataset[0]) +``` + +**Hezar Dataset Processors** + +Here's an equivalent code using the `SpeechRecognitionDatasetProcessor` that has implemented the same map function as a +callable (`SpeechRecognitionDatasetProcessor.__call__()`) that works with both single and batched inputs out of the box! +```python +from datasets import load_dataset, Audio + +from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator +from hezar.preprocessors import Preprocessor + +dataset = load_dataset("hezarai/common-voice-13-fa", split="train") +dataset = dataset.cast_column("audio", Audio(sampling_rate=16000)) +dataset = dataset.select_columns(["sentence", "audio"]) + +preprocesssor = Preprocessor.load("hezarai/whisper-small-fa") + +dataset_processor = SpeechRecognitionDatasetProcessor( + tokenizer=preprocesssor.tokenizer, + feature_extractor=preprocesssor.audio_feature_extractor, + transcript_column="sentence", + audio_array_padding="max_length", +) +data_collator = SpeechRecognitionDataCollator( + feature_extractor=preprocesssor.audio_feature_extractor, + tokenizer=preprocesssor.tokenizer, + labels_padding="max_length", + labels_max_length=256, +) +processed_dataset = dataset.map( + dataset_processor, + batched=True, + batch_size=100, + desc="Processing dataset..." +) +processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"]) +print(processed_dataset[0]) +``` + +## How Dataset Processors Work +Dataset processors classes are callable classes that receive dataset rows/batches and process them when used as a map function +with `datasets.Dataset.map()`. 
Here are the current supported dataset processors: +- `ImageCaptioningDatasetProcessor` +- `OCRDatasetProcessor` +- `SequenceLabelingDatasetProcessor` +- `SpeechRecognitionDatasetProcessor` +- `TextClassificationDatasetProcessor` +- `TextSummarizationDatasetProcessor` + +All the above classes inherit from the base `DatasetProcessor` class and must implement the following two methods: +- `process_single(data, **kwargs)` +- `process_batch(data, **kwargs)` + +The main `__call__()` method is implemented in the base class to figure out if the input `data` is a single row or a batch. + + +## A Training Example +Let's see how we can use a dataset processor to load and process a Hub dataset for speech recognition and train a Whisper model. + +```python +from datasets import load_dataset, Audio + +from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator +from hezar.preprocessors import Preprocessor +from hezar.trainer import Trainer, TrainerConfig +from hezar.models import Model + +dataset = load_dataset("hezarai/common-voice-13-fa", split="train") +dataset = dataset.cast_column("audio", Audio(sampling_rate=16000)) +dataset = dataset.select_columns(["sentence", "audio"]) + +base_model_path = "hezarai/whisper-small" +preprocesssor = Preprocessor.load(base_model_path) + +dataset_processor = SpeechRecognitionDatasetProcessor( + tokenizer=preprocesssor.tokenizer, + feature_extractor=preprocesssor.audio_feature_extractor, + transcript_column="sentence", + audio_array_padding="max_length", +) +# This is the same data collator used in `SpeechRecognitionDataset` +data_collator = SpeechRecognitionDataCollator( + feature_extractor=preprocesssor.audio_feature_extractor, + tokenizer=preprocesssor.tokenizer, + labels_padding="max_length", + labels_max_length=256, +) +processed_dataset = dataset.map( + dataset_processor, + batched=True, + batch_size=100, + desc="Processing dataset..." +) +# Select needed columns for training +processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"]) +# Split dataset for train/evaluation +processed_dataset = processed_dataset.train_test_split(test_size=0.1) + +model = Model.load(base_model_path) + +train_config = TrainerConfig( + output_dir="whisper-small-fa-commonvoice", + task="speech_recognition", + init_weights_from=base_model_path, + mixed_precision="bf16", + gradient_accumulation_steps=8, + batch_size=4, + num_epochs=5, + metrics=["cer", "wer"], +) + +trainer = Trainer( + config=train_config, + model=model, + train_dataset=processed_dataset["train"], + eval_dataset=processed_dataset["test"], + data_collator=data_collator, +) +trainer.train() +``` + +## Wrap-up +Dataset processors are simple, yet powerful callable classes to be used for dataset processing using the `.map()` function +in 🤗Datasets. This integration means that all 🤗Dataset features are unlocked when working with Hezar! 
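One of the 🤗Dataset features called out above is worth spelling out: cached, multi-process mapping. The snippet below is a minimal sketch of that, reusing the repo and column names from the examples in this guide; `num_proc` and `load_from_cache_file` are standard `datasets.Dataset.map()` arguments rather than anything Hezar-specific, and multi-process mapping assumes the processor object is picklable.

```python
from datasets import load_dataset, Audio

from hezar.data import SpeechRecognitionDatasetProcessor
from hezar.preprocessors import Preprocessor

dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.select_columns(["sentence", "audio"])

preprocessor = Preprocessor.load("hezarai/whisper-small-fa")
dataset_processor = SpeechRecognitionDatasetProcessor(
    tokenizer=preprocessor.tokenizer,
    feature_extractor=preprocessor.audio_feature_extractor,
    transcript_column="sentence",
    audio_array_padding="max_length",
)

# `num_proc` fans the map out over worker processes; the processed result is cached on disk,
# so re-running the script reuses the cache instead of reprocessing the audio.
processed_dataset = dataset.map(
    dataset_processor,
    batched=True,
    batch_size=100,
    num_proc=4,
    load_from_cache_file=True,
    desc="Processing dataset...",
)
processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
```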
diff --git a/docs/guide/index.md b/docs/guide/index.md index e0368c9a..91f21de5 100644 --- a/docs/guide/index.md +++ b/docs/guide/index.md @@ -7,6 +7,7 @@ Welcome to the developer guide section where you can take a deeper dive into the hezar_architecture.md models_advanced.md +dataset_processors.md trainer_in_depth.md advanced_training.md ``` diff --git a/docs/tutorial/datasets.md b/docs/tutorial/datasets.md index b29ab0f8..e877c6da 100644 --- a/docs/tutorial/datasets.md +++ b/docs/tutorial/datasets.md @@ -118,7 +118,96 @@ class ImageCaptioningDataset(Dataset): pass ``` -## Loading Regular HF Datasets +## Loading with 🤗Datasets +You can load all Hezar datasets using the 🤗Datasets library too. Doing so has the following pros and cons: + +**Pros**: +- You can work with any dataset on the Hub and use it easily with Hezar. +- You can leverage multiprocessing and batch processing feature of such datasets (which is not available using torch datasets). +- You can leverage mapped dataset caching provided by 🤗Datasets. +- No integration needed for your old data pipeline codes to make them work with Hezar. + +**Cons**: +- You have to take care of the data processing yourself unless one of the dataset processors at `hezar.data.dataset_processors` suits your needs. + +### Using Hezar's Dataset Processors +In order to replicate the same behavior of the `hezar.data.Dataset` classes for 🤗 loaded dataset, Hezar also implements +a group of dataset processor classes so that you can use them to map the loaded datasets and get the same processed instances +when iterating over your loaded 🤗 datasets. + +Below is a comparison of both methods, using Hezar's torch compatible datasets vs 🤗 Datasets: + +**Loading and Processing with Hezar** + +```python +from torch.utils.data import DataLoader +from hezar.data import SpeechRecognitionDataset, SpeechRecognitionDatasetConfig + +# You can also use the regular `Dataset.load("hezarai/common-voice-13-fa")`, below is for better understanding. +dataset = SpeechRecognitionDataset( + SpeechRecognitionDatasetConfig( + path="hezarai/common-voice-13-fa", + sampling_rate=16000, + audio_file_path_column="path", + audio_column="audio", + audio_array_column="array", + transcript_column="sentence", + ), + split="train", + preprocessor="hezarai/whisper-small-fa", +) + +loader = DataLoader(dataset, batch_size=16, collate_fn=dataset.data_collator) +itr = iter(loader) +print(next(itr)) +``` + +**Loading and Processing with 🤗Datasets and Hezar Dataset Processors** + +```python +from datasets import load_dataset, Audio +from torch.utils.data import DataLoader + +from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator +from hezar.preprocessors import Preprocessor + +dataset = load_dataset("hezarai/common-voice-13-fa", split="train") +dataset = dataset.cast_column("audio", Audio(sampling_rate=16000)) +dataset = dataset.select_columns(["sentence", "audio"]) +preprocesssor = Preprocessor.load("hezarai/whisper-small-fa") + +dataset_processor = SpeechRecognitionDatasetProcessor( + tokenizer=preprocesssor.tokenizer, + feature_extractor=preprocesssor.audio_feature_extractor, + transcript_column="sentence", + audio_array_padding="max_length", +) +data_collator = SpeechRecognitionDataCollator( + feature_extractor=preprocesssor.audio_feature_extractor, + tokenizer=preprocesssor.tokenizer, + labels_padding="max_length", + labels_max_length=256, +) +processed_dataset = dataset.map( + dataset_processor, + batched=True, + batch_size=100, + desc="Processing dataset..." 
+) +processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"]) +data_loader = DataLoader(processed_dataset, batch_size=16, collate_fn=data_collator) +x = next(iter(data_loader)) +print(x) +``` +Both codes above, give you the same kind of results. Although using dataset processors is more complicated, but it gives +you more control and better integration with typical data pipelines used nowadays. + +```{note} +You don't necessarily need to use the dataset processor classes in Hezar. They are just there to implement the same +procedures and reproduce the same results. This means that any code that uses 🤗 Datasets will work with Hezar's Trainer. +``` + +### Loading Regular HF Datasets All the current datasets provided in Hezar's Hugging Face, have the `dataset_config.yaml` in their repos which does not exist for regular HF datasets. If you need to load such datasets (that have the correct structure and fields) in Hezar using the `Dataset.load()` method, you have to provide the dataset config manually. From aa123632b20121e3f0545d876a4018ee952ede26 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 14 Nov 2024 10:37:42 +0330 Subject: [PATCH 32/33] :bug: Fix some data processor issues --- hezar/data/dataset_processors.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/hezar/data/dataset_processors.py b/hezar/data/dataset_processors.py index 3d32f084..ad5d615c 100644 --- a/hezar/data/dataset_processors.py +++ b/hezar/data/dataset_processors.py @@ -144,12 +144,12 @@ def process_single(self, data, return_tensors=None, padding=None, max_length=Non data["pixel_values"] = self.image_processor(path, return_tensors=return_tensors)["pixel_values"] data["labels"] = tokenized_inputs["token_ids"] - data["attention_mask"] = tokenized_inputs["attention_mask"] + data["decoder_attention_mask"] = tokenized_inputs["attention_mask"] data["decoder_input_ids"] = self._shift_tokens_right( - data["labels"], + [data["labels"]], pad_token_id=self.tokenizer.pad_token_id, decoder_start_token_id=self.tokenizer.bos_token_id, - ) + )[0] return data @@ -270,9 +270,7 @@ def process_batch(self, data, return_tensors=None): pixel_values = self.image_processor(paths, return_tensors=return_tensors)["pixel_values"] # Process text labels in batch - labels = [] - for text in texts: - labels.append(self._text_to_ids(text)) + labels = [self._text_to_ids(text) for text in texts] return {"pixel_values": pixel_values, "labels": labels} @@ -363,7 +361,7 @@ def process_single(self, data, return_tensors=None, padding=None, max_length=Non padding=padding, max_length=max_length, ) - + tokenized_inputs = {k: v[0] for k, v in tokenized_inputs.items()} data.update(tokenized_inputs) return data @@ -538,7 +536,7 @@ def process_single(self, data, return_tensors=None, padding=None, max_length=Non return_tensors=return_tensors, ) data.update(inputs) - data["labels"] = torch.tensor([label], dtype=torch.long) + data["labels"] = torch.tensor(label, dtype=torch.long) return data From 07ab70341347cf6542203175245d106b7cc9e89c Mon Sep 17 00:00:00 2001 From: arxyzan Date: Thu, 14 Nov 2024 11:55:03 +0330 Subject: [PATCH 33/33] :fire: Remove deprecated tokenizer config args --- hezar/models/speech_recognition/whisper/whisper_tokenizer.py | 3 --- hezar/preprocessors/tokenizers/bpe.py | 3 --- hezar/preprocessors/tokenizers/sentencepiece_bpe.py | 3 --- hezar/preprocessors/tokenizers/sentencepiece_unigram.py | 3 --- hezar/preprocessors/tokenizers/tokenizer.py | 3 --- 
hezar/preprocessors/tokenizers/wordpiece.py | 3 --- 6 files changed, 18 deletions(-) diff --git a/hezar/models/speech_recognition/whisper/whisper_tokenizer.py b/hezar/models/speech_recognition/whisper/whisper_tokenizer.py index c980e78a..9ff23e9c 100644 --- a/hezar/models/speech_recognition/whisper/whisper_tokenizer.py +++ b/hezar/models/speech_recognition/whisper/whisper_tokenizer.py @@ -254,11 +254,8 @@ @dataclass class WhisperBPEConfig(BPEConfig): name = "whisper_bpe_tokenizer" - max_length: int = 448 - truncation: str = "longest_first" truncation_side: str = "right" stride: int = 0 - padding: str = "longest" padding_side: str = "right" pad_to_multiple_of: int = 0 pad_token: str = "<|endoftext|>" diff --git a/hezar/preprocessors/tokenizers/bpe.py b/hezar/preprocessors/tokenizers/bpe.py index 51802eed..3da3a7dc 100644 --- a/hezar/preprocessors/tokenizers/bpe.py +++ b/hezar/preprocessors/tokenizers/bpe.py @@ -19,11 +19,8 @@ @dataclass class BPEConfig(TokenizerConfig): name = "bpe_tokenizer" - max_length: int = 512 - truncation: str = "longest_first" truncation_side: str = "right" stride: int = 0 - padding: str = "longest" padding_side: str = "right" pad_to_multiple_of: int = 0 bos_token: str = "" diff --git a/hezar/preprocessors/tokenizers/sentencepiece_bpe.py b/hezar/preprocessors/tokenizers/sentencepiece_bpe.py index ecb0df1e..36cda5b8 100644 --- a/hezar/preprocessors/tokenizers/sentencepiece_bpe.py +++ b/hezar/preprocessors/tokenizers/sentencepiece_bpe.py @@ -19,11 +19,8 @@ @dataclass class SentencePieceBPEConfig(TokenizerConfig): name = "sentencepiece_bpe_tokenizer" - max_length: int = 512 - truncation: str = "longest_first" truncation_side: str = "right" stride: int = 0 - padding: str = "longest" padding_side: str = "right" bos_token: str = "" eos_token: str = "" diff --git a/hezar/preprocessors/tokenizers/sentencepiece_unigram.py b/hezar/preprocessors/tokenizers/sentencepiece_unigram.py index c38d352d..834b6dc0 100644 --- a/hezar/preprocessors/tokenizers/sentencepiece_unigram.py +++ b/hezar/preprocessors/tokenizers/sentencepiece_unigram.py @@ -19,11 +19,8 @@ @dataclass class SentencePieceUnigramConfig(TokenizerConfig): name = "sentencepiece_unigram_tokenizer" - max_length: int = 512 - truncation: str = "longest_first" truncation_side: str = "right" stride: int = 0 - padding: str = "longest" padding_side: str = "right" bos_token: str = "" eos_token: str = "" diff --git a/hezar/preprocessors/tokenizers/tokenizer.py b/hezar/preprocessors/tokenizers/tokenizer.py index 15be87ff..33a83e35 100644 --- a/hezar/preprocessors/tokenizers/tokenizer.py +++ b/hezar/preprocessors/tokenizers/tokenizer.py @@ -42,11 +42,8 @@ class TokenizerConfig(PreprocessorConfig): Configuration for the Tokenizer. Args: - max_length (int): Maximum length of the tokenized sequences. - truncation (str): Truncation strategy for tokenization. truncation_side (str): Truncation direction for tokenization. stride (int): Stride for tokenization. - padding (str): Padding type for tokenization e.g, max_length, longest, no_padding. padding_side (str): Padding direction for tokenization. pad_to_multiple_of (int): Pad to a multiple of this value. pad_token_type_id (int): ID of the padding token type. 
diff --git a/hezar/preprocessors/tokenizers/wordpiece.py b/hezar/preprocessors/tokenizers/wordpiece.py index b97d30ef..60644fff 100644 --- a/hezar/preprocessors/tokenizers/wordpiece.py +++ b/hezar/preprocessors/tokenizers/wordpiece.py @@ -19,11 +19,8 @@ @dataclass class WordPieceConfig(TokenizerConfig): name = "wordpiece_tokenizer" - max_length: int = 512 - truncation: str = "longest_first" truncation_side: str = "right" stride: int = 0 - padding: str = "longest" padding_side: str = "right" pad_to_multiple_of: int = 0 pad_token: str = "[PAD]"
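With these config fields gone, sequence-length handling happens at call time rather than in the tokenizer config. The snippet below is only a rough sketch of that usage: the Hub path is a placeholder, it assumes the usual `Tokenizer.load` entry point, and the `padding`, `truncation`, `max_length`, and `return_tensors` arguments are the call-time parameters of `Tokenizer.__call__` shown earlier in this series.

```python
from hezar.preprocessors import Tokenizer

# Placeholder Hub path, used purely for illustration
tokenizer = Tokenizer.load("hezarai/bert-base-fa")

# A single string now comes back unbatched, as plain python lists
single = tokenizer("سلام دنیا")
print(single["token_ids"])

# Padding, truncation, and max length are passed per call instead of being read from the config
batch = tokenizer(
    ["سلام دنیا", "این جمله کمی طولانی‌تر است"],
    padding="max_length",
    truncation="longest_first",
    max_length=32,
    return_tensors="torch",
)
print(batch["token_ids"].shape)
```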