diff --git a/docs/guide/dataset_processors.md b/docs/guide/dataset_processors.md
new file mode 100644
index 00000000..6ba99c25
--- /dev/null
+++ b/docs/guide/dataset_processors.md
@@ -0,0 +1,186 @@
+# Dataset Processors
+Initially, Hezar's Trainer worked only with PyTorch datasets (derived from `torch.utils.data.Dataset`) like all the Hezar dataset classes
+in `hezar.data.datasets`. Later on, we also added support for any iterable as the dataset in Hezar's Trainer.
+
+One particularly important dataset type is 🤗 Datasets. The Trainer has supported this type of dataset almost since day one,
+but the data pipelines had to be implemented by the user. That's why Hezar (`>=v0.42.0`) added a new category of classes
+called Dataset Processors. These classes are used as dataset map callables, which brings the following benefits:
+- The same processing pipeline as in the corresponding `hezar.data.Dataset` subclass is implemented as a map function.
+For example, `SpeechRecognitionDatasetProcessor` corresponds to `SpeechRecognitionDataset`.
+- Features like caching, multiprocessing, batch processing, etc. become available since the objects are of type `datasets.Dataset`.
+- Dataset processing pipelines from other codebases become plug-and-play with Hezar's `Trainer`.
+
+Now let's see an example demonstrating both cases:
+
+**Classic 🤗Datasets**
+
+Here we need to implement a map function that processes our samples. The 🤗Datasets `map` function works on callables that
+operate on either single or batched inputs. Below is an implementation for batched processing:
+```python
+from datasets import load_dataset, Audio
+from hezar.preprocessors import Preprocessor
+
+
+preprocessor = Preprocessor.load("hezarai/whisper-small-fa")
+feature_extractor = preprocessor.audio_feature_extractor
+tokenizer = preprocessor.tokenizer
+
+def batch_process_fn(data):
+    # Extract audio arrays and transcripts
+    audio_arrays = [x["array"] for x in data["audio"]]  # Audio arrays are stored under the "audio" key
+    transcripts = data["sentence"]  # Transcripts are stored under the "sentence" key
+
+    # Extract input features in batch
+    input_features = feature_extractor(
+        audio_arrays,
+        sampling_rate=16000,
+        return_tensors="numpy",  # Return as numpy for compatibility with map
+    )["input_features"]
+
+    # Tokenize transcripts in batch
+    labels = tokenizer(
+        transcripts,
+        padding="max_length",
+        max_length=448,
+        return_tensors="numpy",
+    )
+
+    # Add processed data to the dictionary
+    data["input_features"] = input_features
+    data["labels"] = labels["token_ids"]
+    data["attention_mask"] = labels["attention_mask"]
+
+    return data
+
+dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
+dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+dataset = dataset.select_columns(["sentence", "audio"])
+# Apply the function to the dataset using map
+processed_dataset = dataset.map(batch_process_fn, batched=True)
+processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
+print(processed_dataset[0])
+```
+
+**Hezar Dataset Processors**
+
+Here's equivalent code using the `SpeechRecognitionDatasetProcessor`, which implements the same map function as a
+callable (`SpeechRecognitionDatasetProcessor.__call__()`) that works with both single and batched inputs out of the box!
+```python
+from datasets import load_dataset, Audio
+
+from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator
+from hezar.preprocessors import Preprocessor
+
+dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
+dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+dataset = dataset.select_columns(["sentence", "audio"])
+
+preprocessor = Preprocessor.load("hezarai/whisper-small-fa")
+
+dataset_processor = SpeechRecognitionDatasetProcessor(
+    tokenizer=preprocessor.tokenizer,
+    feature_extractor=preprocessor.audio_feature_extractor,
+    transcript_column="sentence",
+    audio_array_padding="max_length",
+)
+data_collator = SpeechRecognitionDataCollator(
+    feature_extractor=preprocessor.audio_feature_extractor,
+    tokenizer=preprocessor.tokenizer,
+    labels_padding="max_length",
+    labels_max_length=256,
+)
+processed_dataset = dataset.map(
+    dataset_processor,
+    batched=True,
+    batch_size=100,
+    desc="Processing dataset..."
+)
+processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
+print(processed_dataset[0])
+```
+
+## How Dataset Processors Work
+Dataset processor classes are callables that receive dataset rows/batches and process them when used as map functions
+with `datasets.Dataset.map()`. Here are the currently supported dataset processors:
+- `ImageCaptioningDatasetProcessor`
+- `OCRDatasetProcessor`
+- `SequenceLabelingDatasetProcessor`
+- `SpeechRecognitionDatasetProcessor`
+- `TextClassificationDatasetProcessor`
+- `TextSummarizationDatasetProcessor`
+
+All the above classes inherit from the base `DatasetProcessor` class and must implement the following two methods:
+- `process_single(data, **kwargs)`
+- `process_batch(data, **kwargs)`
+
+The main `__call__()` method is implemented in the base class and figures out whether the input `data` is a single row
+or a batch, dispatching to the right method accordingly.
+
+
+## A Training Example
+Let's see how we can use a dataset processor to load and process a Hub dataset for speech recognition and train a Whisper model.
+
+```python
+from datasets import load_dataset, Audio
+
+from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator
+from hezar.preprocessors import Preprocessor
+from hezar.trainer import Trainer, TrainerConfig
+from hezar.models import Model
+
+dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
+dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+dataset = dataset.select_columns(["sentence", "audio"])
+
+base_model_path = "hezarai/whisper-small"
+preprocessor = Preprocessor.load(base_model_path)
+
+dataset_processor = SpeechRecognitionDatasetProcessor(
+    tokenizer=preprocessor.tokenizer,
+    feature_extractor=preprocessor.audio_feature_extractor,
+    transcript_column="sentence",
+    audio_array_padding="max_length",
+)
+# This is the same data collator used in `SpeechRecognitionDataset`
+data_collator = SpeechRecognitionDataCollator(
+    feature_extractor=preprocessor.audio_feature_extractor,
+    tokenizer=preprocessor.tokenizer,
+    labels_padding="max_length",
+    labels_max_length=256,
+)
+processed_dataset = dataset.map(
+    dataset_processor,
+    batched=True,
+    batch_size=100,
+    desc="Processing dataset..."
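+    # num_proc=4,  # optional: parallelize preprocessing across workers (the worker count here is just an example)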
+)
+# Select the needed columns for training
+processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
+# Split the dataset for train/evaluation
+processed_dataset = processed_dataset.train_test_split(test_size=0.1)
+
+model = Model.load(base_model_path)
+
+train_config = TrainerConfig(
+    output_dir="whisper-small-fa-commonvoice",
+    task="speech_recognition",
+    init_weights_from=base_model_path,
+    mixed_precision="bf16",
+    gradient_accumulation_steps=8,
+    batch_size=4,
+    num_epochs=5,
+    metrics=["cer", "wer"],
+)
+
+trainer = Trainer(
+    config=train_config,
+    model=model,
+    train_dataset=processed_dataset["train"],
+    eval_dataset=processed_dataset["test"],
+    data_collator=data_collator,
+)
+trainer.train()
+```
+
+## Wrap-up
+Dataset processors are simple yet powerful callable classes for processing datasets with the `.map()` function
+in 🤗Datasets. This integration means that all 🤗Datasets features are unlocked when working with Hezar!
diff --git a/docs/guide/index.md b/docs/guide/index.md
index e0368c9a..91f21de5 100644
--- a/docs/guide/index.md
+++ b/docs/guide/index.md
@@ -7,6 +7,7 @@ Welcome to the developer guide section where you can take a deeper dive into
 hezar_architecture.md
 models_advanced.md
+dataset_processors.md
 trainer_in_depth.md
 advanced_training.md
 ```
diff --git a/docs/tutorial/datasets.md b/docs/tutorial/datasets.md
index b29ab0f8..e877c6da 100644
--- a/docs/tutorial/datasets.md
+++ b/docs/tutorial/datasets.md
@@ -118,7 +118,96 @@ class ImageCaptioningDataset(Dataset):
         pass
 ```
 
-## Loading Regular HF Datasets
+## Loading with 🤗Datasets
+You can also load all Hezar datasets using the 🤗Datasets library. Doing so has the following pros and cons:
+
+**Pros**:
+- You can work with any dataset on the Hub and use it easily with Hezar.
+- You can leverage the multiprocessing and batch processing features of such datasets (not available with torch datasets).
+- You can leverage the mapped-dataset caching provided by 🤗Datasets.
+- No integration is needed for your old data pipeline code to work with Hezar.
+
+**Cons**:
+- You have to take care of the data processing yourself, unless one of the dataset processors at `hezar.data.dataset_processors` suits your needs.
+
+### Using Hezar's Dataset Processors
+To replicate the behavior of the `hezar.data.Dataset` classes for datasets loaded with 🤗Datasets, Hezar also implements
+a group of dataset processor classes. You can use them to map your loaded 🤗 datasets and get the same processed instances
+when iterating over them.
+
+Below is a comparison of both methods, using Hezar's torch-compatible datasets vs 🤗 Datasets:
+
+**Loading and Processing with Hezar**
+
+```python
+from torch.utils.data import DataLoader
+from hezar.data import SpeechRecognitionDataset, SpeechRecognitionDatasetConfig
+
+# You can also use the regular `Dataset.load("hezarai/common-voice-13-fa")`; the explicit form below is shown for clarity.
+dataset = SpeechRecognitionDataset(
+    SpeechRecognitionDatasetConfig(
+        path="hezarai/common-voice-13-fa",
+        sampling_rate=16000,
+        audio_file_path_column="path",
+        audio_column="audio",
+        audio_array_column="array",
+        transcript_column="sentence",
+    ),
+    split="train",
+    preprocessor="hezarai/whisper-small-fa",
+)
+
+loader = DataLoader(dataset, batch_size=16, collate_fn=dataset.data_collator)
+itr = iter(loader)
+print(next(itr))
+```
+
+**Loading and Processing with 🤗Datasets and Hezar Dataset Processors**
+
+```python
+from datasets import load_dataset, Audio
+from torch.utils.data import DataLoader
+
+from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator
+from hezar.preprocessors import Preprocessor
+
+dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
+dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+dataset = dataset.select_columns(["sentence", "audio"])
+preprocessor = Preprocessor.load("hezarai/whisper-small-fa")
+
+dataset_processor = SpeechRecognitionDatasetProcessor(
+    tokenizer=preprocessor.tokenizer,
+    feature_extractor=preprocessor.audio_feature_extractor,
+    transcript_column="sentence",
+    audio_array_padding="max_length",
+)
+data_collator = SpeechRecognitionDataCollator(
+    feature_extractor=preprocessor.audio_feature_extractor,
+    tokenizer=preprocessor.tokenizer,
+    labels_padding="max_length",
+    labels_max_length=256,
+)
+processed_dataset = dataset.map(
+    dataset_processor,
+    batched=True,
+    batch_size=100,
+    desc="Processing dataset..."
+)
+processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
+data_loader = DataLoader(processed_dataset, batch_size=16, collate_fn=data_collator)
+x = next(iter(data_loader))
+print(x)
+```
+Both snippets above give you the same kind of results. Although using dataset processors takes a bit more code, it gives
+you more control and better integration with the data pipelines typically used nowadays.
+
+```{note}
+You don't necessarily need to use the dataset processor classes in Hezar. They are just there to implement the same
+procedures and reproduce the same results. This means that any code that uses 🤗 Datasets will work with Hezar's Trainer.
+```
+
+### Loading Regular HF Datasets
 All the current datasets provided in Hezar's Hugging Face, have the `dataset_config.yaml` in their repos which does not exist
 for regular HF datasets. If you need to load such datasets (that have the correct structure and fields) in Hezar using the
 `Dataset.load()` method, you have to provide the dataset config manually.
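+
+For example, a minimal sketch (the dataset path and column names below are hypothetical placeholders):
+
+```python
+from hezar.data import SpeechRecognitionDataset, SpeechRecognitionDatasetConfig
+
+# Build the dataset config manually since the repo has no `dataset_config.yaml`
+config = SpeechRecognitionDatasetConfig(
+    path="some-org/some-asr-dataset",  # hypothetical HF dataset with audio/transcript columns
+    sampling_rate=16000,
+    audio_column="audio",
+    transcript_column="sentence",
+)
+dataset = SpeechRecognitionDataset(config, split="train", preprocessor="hezarai/whisper-small-fa")
+```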
diff --git a/examples/data/dataset_processing_example.py b/examples/data/dataset_processing_example.py new file mode 100644 index 00000000..a02958c3 --- /dev/null +++ b/examples/data/dataset_processing_example.py @@ -0,0 +1,154 @@ +# from datasets import load_dataset +# +# from hezar.data import ImageCaptioningDatasetProcessor +# from hezar.preprocessors import Tokenizer, ImageProcessor +# +# dataset = load_dataset("hezarai/flickr30k-fa", split="train").select(indices=list(range(4000))) +# tokenizer = Tokenizer.load("hezarai/vit-roberta-fa-image-captioning-flickr30k") +# image_processor = ImageProcessor.load("hezarai/vit-roberta-fa-image-captioning-flickr30k") +# dataset_processor = ImageCaptioningDatasetProcessor(tokenizer=tokenizer, image_processor=image_processor) +# +# processed_dataset = dataset.map( +# dataset_processor, +# # batched=True, +# # batch_size=1000, +# load_from_cache_file=False, +# # num_proc=4, +# desc="Processing dataset..." +# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + +# +# from datasets import load_dataset +# +# from hezar.data import TextClassificationDatasetProcessor +# from hezar.preprocessors import Tokenizer +# +# dataset = load_dataset("hezarai/sentiment-dksf", split="train") +# tokenizer = Tokenizer.load("hezarai/roberta-base-fa") +# dataset_processor = TextClassificationDatasetProcessor(tokenizer=tokenizer, padding="longest") +# +# processed_dataset = dataset.map( +# dataset_processor, +# batched=True, +# batch_size=1000, +# load_from_cache_file=False, +# num_proc=4, +# desc="Processing dataset..." +# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + + +# from datasets import load_dataset +# +# from hezar.data import SequenceLabelingDatasetProcessor +# from hezar.preprocessors import Tokenizer +# +# dataset = load_dataset("hezarai/lscp-pos-500k", split="train") +# tokenizer = Tokenizer.load("hezarai/roberta-base-fa") +# dataset_processor = SequenceLabelingDatasetProcessor(tokenizer=tokenizer, padding="longest") +# +# processed_dataset = dataset.map( +# dataset_processor, +# batched=True, +# batch_size=1000, +# load_from_cache_file=False, +# # num_proc=4, +# desc="Processing dataset..." +# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + + +# from datasets import load_dataset +# +# from hezar.data import TextSummarizationDatasetProcessor +# from hezar.preprocessors import Tokenizer +# +# dataset = load_dataset("hezarai/xlsum-fa", split="train") +# tokenizer = Tokenizer.load("hezarai/t5-base-fa") +# dataset_processor = TextSummarizationDatasetProcessor(tokenizer=tokenizer, padding="longest") +# +# processed_dataset = dataset.map( +# dataset_processor, +# # batched=True, +# # batch_size=1000, +# load_from_cache_file=False, +# num_proc=10, +# desc="Processing dataset..." 
+# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + +# from datasets import load_dataset +# +# from hezar.data import OCRDatasetProcessor +# from hezar.preprocessors import ImageProcessor +# from hezar.configs import ModelConfig +# from hezar.utils import is_text_valid +# +# +# dataset = load_dataset("hezarai/parsynth-ocr-200k", split="train[:3000]") +# id2label = ModelConfig.load("hezarai/crnn-fa-printed-96-long", filename="model_config.yaml")["id2label"] # hack +# +# # Cleanup dataset +# max_length = 48 +# valid_indices = [] +# invalid_indices = [] +# for i, sample in enumerate(list(iter(dataset))): +# path, text = sample.values() +# if len(text) <= max_length and is_text_valid(text, id2label.values()): +# valid_indices.append(i) +# dataset = dataset.select(valid_indices) +# +# image_processor = ImageProcessor.load("hezarai/crnn-fa-printed-96-long") +# dataset_processor = OCRDatasetProcessor(image_processor=image_processor, id2label=id2label) +# processed_dataset = dataset.map( +# dataset_processor, +# # batched=True, +# # batch_size=1000, +# load_from_cache_file=False, +# # num_proc=10, +# desc="Processing dataset..." +# ) +# processed_dataset.set_format("torch") +# print(processed_dataset[0]) + +from datasets import load_dataset +from torch.utils.data import DataLoader + +from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator +from hezar.preprocessors import Tokenizer, AudioFeatureExtractor + +dataset = load_dataset("parquet", split="train", data_files=["train-00001-of-00002.parquet"]).select(list(range(100))) +dataset = dataset.select_columns(["sentence", "audio"]) +tokenizer = Tokenizer.load("hezarai/whisper-small-fa") +feature_extractor = AudioFeatureExtractor.load("hezarai/whisper-small-fa") +dataset_processor = SpeechRecognitionDatasetProcessor( + tokenizer=tokenizer, + feature_extractor=feature_extractor, + transcript_field="sentence", + labels_padding=None, + audio_array_padding="max_length", +) +data_collator = SpeechRecognitionDataCollator( + feature_extractor=feature_extractor, + tokenizer=tokenizer, + labels_padding="max_length", + labels_max_length=256, +) +processed_dataset = dataset.map( + dataset_processor, + batched=True, + batch_size=100, + load_from_cache_file=False, + # num_proc=10, + desc="Processing dataset..." +) +processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"]) +processed_dataset.set_format("torch") +data_loader = DataLoader(processed_dataset, batch_size=16, collate_fn=data_collator) +x = next(iter(data_loader)) +print(x) diff --git a/hezar/data/__init__.py b/hezar/data/__init__.py index af95c156..93d4354a 100644 --- a/hezar/data/__init__.py +++ b/hezar/data/__init__.py @@ -1,5 +1,6 @@ from ..registry import datasets_registry # noqa from ..builders import build_dataset # noqa +from .dataset_processors import * from .data_collators import * from .data_samplers import * from .datasets import * diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index b752b56a..3e954968 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -1,3 +1,5 @@ +from collections import defaultdict + import numpy as np import torch @@ -17,6 +19,24 @@ logger = Logger(__name__) +def _convert_to_batch_dict(dicts_list: list[dict]): + """ + Convert a list of dicts to a dict of batched values. 
+ + Args: + dicts_list: A list of dictionaries containing the same set of keys + + Returns: + A dictionary of the batches + """ + batch_dict = defaultdict(list) + for item in dicts_list: + for key, value in item.items(): + batch_dict[key].append(value) + batch_dict = dict(batch_dict) + return batch_dict + + class TextPaddingDataCollator: """ A data collator that pads a batch of tokenized inputs. @@ -52,39 +72,22 @@ def __init__( "attention_mask": 0, } - if padding == "longest" and max_length is not None: - logger.warning( - "You passed `max_length` while also setting `padding` to `longest` which are " - "incompatible! Instead leave `max_length` as None or set `padding` to `max_length`! " - "Ignoring `max_length`" - ) - self.max_length = None - - def __call__(self, encoded_batch): + def __call__(self, input_batch): """ Add padding to every item in the batch Args: - encoded_batch: A batch dictionary + input_batch: A batch dictionary Returns: Dict: The same batch dictionary but padded """ - encoded_batch = [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - permuted_batch = {} - for key in encoded_batch[0].keys(): - stack = [e for item in encoded_batch for e in item[key]] - permuted_batch[key] = stack - - encoded_batch = permuted_batch.copy() - if "label" in encoded_batch: - encoded_batch["labels"] = encoded_batch["label"] - del encoded_batch["label"] - - labels = encoded_batch.pop("labels") - input_length = self.max_length or max(len(x) for x in encoded_batch["token_ids"]) + input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] + input_batch = _convert_to_batch_dict(input_batch) + labels = input_batch.pop("labels") + input_length = self.max_length or max(len(x) for x in input_batch["token_ids"]) - for field, batch in encoded_batch.items(): + for field, batch in input_batch.items(): padded_batch = [] for x in batch: if isinstance(x, torch.Tensor): @@ -95,13 +98,13 @@ def __call__(self, encoded_batch): paddings = [self.field_to_pad_id_mapping[field]] * difference padded_x = x + paddings if self.padding_side == "right" else paddings + x padded_batch.append(padded_x) - encoded_batch[field] = padded_batch + input_batch[field] = padded_batch - encoded_batch["labels"] = labels + input_batch["labels"] = labels - encoded_batch = convert_batch_dict_dtype(encoded_batch, dtype=self.return_tensors) + input_batch = convert_batch_dict_dtype(input_batch, dtype=self.return_tensors) - return encoded_batch + return input_batch class TextGenerationDataCollator: @@ -114,8 +117,7 @@ class TextGenerationDataCollator: padding_side (str): Specifies from which side of each tensor to add paddings, either `left` or `right` max_length (int): If `padding` is set to `max_length` this must be specified. Forces all tensors to have this value as length. - max_target_length (int): Maximum target length for text generation. - return_tensors (str): Specifies the dtype of the returning tensors in the batch. (`numpy`, `list`, `torch`) + labels_max_length (int): Maximum target length for text generation. 
""" @@ -125,53 +127,39 @@ def __init__( padding: str = "longest", padding_side: str = "right", max_length: int = None, - max_target_length: int = None, - return_tensors: str = "torch", + labels_max_length: int = None, ): self.tokenizer = tokenizer self.padding = padding self.padding_side = padding_side self.max_length = max_length - self.max_target_length = max_target_length - self.return_tensors = return_tensors - - if padding == "longest" and max_length is not None: - logger.warning( - "You passed `max_length` while also setting `padding` to `longest` which are " - "incompatible! Instead leave `max_length` as None or set `padding` to `max_length`! " - "Ignoring `max_length`" - ) - self.max_length = None + self.labels_max_length = labels_max_length - def __call__(self, encoded_batch): + def __call__(self, input_batch): """ Add padding to every item in the batch Args: - encoded_batch (List[Dict]): A batch dictionary + input_batch (List[Dict]): A batch dictionary Returns: Dict: The same batch dictionary but padded """ - encoded_batch = [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - permuted_batch = {} - for key in encoded_batch[0].keys(): - stack = [e for item in encoded_batch for e in item[key]] - permuted_batch[key] = stack - + input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] + input_batch = _convert_to_batch_dict(input_batch) padded_batch = self.tokenizer.pad_encoded_batch( - permuted_batch, + input_batch, padding=self.padding, max_length=self.max_length, exclude_keys=["labels"], - return_tensors=self.return_tensors, + return_tensors="torch", ) padded_batch = self.tokenizer.pad_encoded_batch( padded_batch, padding=self.padding, - max_length=self.max_target_length, + max_length=self.labels_max_length, include_keys=["labels"], - return_tensors=self.return_tensors, + return_tensors="torch", ) return padded_batch @@ -187,7 +175,6 @@ class ImageCaptioningDataCollator: padding_side (str): Specifies from which side of each tensor to add paddings, either `left` or `right` max_length (int): If `padding` is set to `max_length` this must be specified. Forces all tensors to have this value as length. - return_tensors (str): Specifies the dtype of the returning tensors in the batch. (`numpy`, `list`, `torch`) """ def __init__( @@ -196,39 +183,30 @@ def __init__( padding: str = "longest", padding_side: str = "right", max_length: int = None, - return_tensors: str = "torch", ): self.tokenizer = tokenizer self.padding = padding self.padding_side = padding_side self.max_length = max_length - self.return_tensors = return_tensors - - if padding == "longest" and max_length is not None: - logger.warning( - "You passed `max_length` while also setting `padding` to `longest` which are " - "incompatible! Instead leave `max_length` as None or set `padding` to `max_length`! 
" - "Ignoring `max_length`" - ) - self.max_length = None - - def __call__(self, encoded_batch): - encoded_batch = [convert_batch_dict_dtype(x, dtype="list") for x in encoded_batch] - permuted_batch = {} - for key in encoded_batch[0].keys(): - stack = [e for item in encoded_batch for e in item[key]] - permuted_batch[key] = stack - padded_batch = self.tokenizer.pad_encoded_batch( - permuted_batch, + def __call__(self, input_batch): + input_batch = _convert_to_batch_dict(input_batch) + input_batch = self.tokenizer.pad_encoded_batch( + input_batch, padding=self.padding, max_length=self.max_length, exclude_keys=["pixel_values"], - return_tensors=self.return_tensors, + return_tensors="torch", ) - padded_batch = convert_batch_dict_dtype(padded_batch, dtype="torch") + if isinstance(input_batch["pixel_values"], list): + if isinstance(input_batch["pixel_values"][0], list): + input_batch["pixel_values"] = torch.tensor(input_batch["pixel_values"]) + elif isinstance(input_batch["pixel_values"][0], torch.Tensor): + input_batch["pixel_values"] = torch.stack(input_batch["pixel_values"]) + elif isinstance(input_batch["pixel_values"][0], np.ndarray): + input_batch["pixel_values"] = torch.stack([torch.from_numpy(x) for x in input_batch["pixel_values"]]) - return padded_batch + return input_batch class SpeechRecognitionDataCollator: @@ -249,18 +227,13 @@ def __init__( self.labels_max_length = labels_max_length def __call__(self, input_batch): - input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch] - inputs = {} - for key in input_batch[0].keys(): - stack = [e for item in input_batch for e in item[key]] - inputs[key] = stack - + input_batch = _convert_to_batch_dict(input_batch) inputs = self.tokenizer.pad_encoded_batch( - inputs, + input_batch, padding=self.labels_padding, max_length=self.labels_max_length, exclude_keys=["input_features"], - return_tensors="torch" + return_tensors="torch", ) inputs = self.feature_extractor.pad( @@ -280,11 +253,10 @@ class SequenceLabelingDataCollator: Args: tokenizer (Tokenizer): A Hezar tokenizer instance. padding (str): Specifies padding strategy, either `longest` or `max_length`. - padding_side (str): Specifies from which side of each tensor to add paddings, either `left` or `right` + padding_side (str): Specifies from which side of each tensor to add paddings, either `left` or `right`. label_pad_token_id (int): Token ID for padding labels. max_length (int): If `padding` is set to `max_length` this must be specified. Forces all tensors to have this value as length. - return_tensors (str): Specifies the dtype of the returning tensors in the batch. 
(`numpy`, `list`, `torch`) """ def __init__( @@ -294,54 +266,52 @@ def __init__( padding_side: str = "right", label_pad_token_id: int = -100, max_length: int = None, - return_tensors: str = "torch", ): self.tokenizer = tokenizer self.padding = padding self.padding_side = padding_side self.label_pad_token_id = label_pad_token_id self.max_length = max_length - self.return_tensors = return_tensors - def __call__(self, encoded_batch): + def __call__(self, input_batch): """ Add padding to every item in the batch Args: - encoded_batch (List[Dict]): A batch dictionary + input_batch (List[Dict]): A batch dictionary Returns: Dict: The same batch dictionary but padded """ - label_name = "label" if "label" in encoded_batch[0].keys() else "labels" - labels = [feature[label_name] for feature in encoded_batch] if label_name in encoded_batch[0].keys() else None + input_batch = _convert_to_batch_dict(input_batch) + labels = input_batch["labels"] self.tokenizer.config.padding_side = self.padding_side - batch = self.tokenizer.pad_encoded_batch( - encoded_batch, + input_batch = self.tokenizer.pad_encoded_batch( + input_batch, padding=self.padding, # noqa max_length=self.max_length, - # Conversion to tensors will fail if we have labels as they are not of the same length yet. - return_tensors="torch" if labels is None else None, + return_tensors="torch", ) if labels is None: - return batch + return input_batch - batch.pop("word_ids", None) - sequence_length = torch.tensor(batch["token_ids"]).shape[1] + input_batch.pop("word_ids", None) + sequence_length = input_batch["token_ids"].shape[1] if self.padding_side == "right": - batch[label_name] = [ + input_batch["labels"] = [ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels ] else: - batch[label_name] = [ + input_batch["labels"] = [ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels ] - batch = { - k: torch.tensor(v, dtype=torch.int64) if not isinstance(v, torch.Tensor) else v for k, v in batch.items() + input_batch = { + k: torch.tensor(v) if not isinstance(v, torch.Tensor) else v for k, v in input_batch.items() } - return batch + + return input_batch class CharLevelOCRDataCollator: @@ -365,14 +335,16 @@ def __call__(self, input_batch): Returns: Dict: Padded input batch. """ - if isinstance(input_batch, (list, tuple)) and isinstance(input_batch[0], dict): - input_batch = {key: [example[key] for example in input_batch] for key in input_batch[0].keys()} - input_batch["pixel_values"] = torch.stack(input_batch["pixel_values"], 0) + input_batch = _convert_to_batch_dict(input_batch) + + if not isinstance(input_batch["pixel_values"][0], torch.Tensor): + input_batch["pixel_values"] = torch.tensor(input_batch["pixel_values"]) + elif isinstance(input_batch["pixel_values"], list) and isinstance(input_batch["pixel_values"][0], torch.Tensor): + input_batch["pixel_values"] = torch.stack(input_batch["pixel_values"]) max_length = max(map(len, input_batch["labels"])) all_labels = [] for labels in input_batch["labels"]: - labels = labels.numpy().tolist() labels += [self.pad_token_id] * (max_length - len(labels)) all_labels.append(labels) input_batch["labels"] = torch.tensor(all_labels) diff --git a/hezar/data/dataset_processors.py b/hezar/data/dataset_processors.py new file mode 100644 index 00000000..ad5d615c --- /dev/null +++ b/hezar/data/dataset_processors.py @@ -0,0 +1,687 @@ +""" +Dataset processors are a bunch of callable classes to be passed as map functions for any dataset on the Hub. 
+Note that the main dataset classes are already implemented in a way that the processing is done in the `__getitem__`
+method, and these classes are only used when the dataset has been loaded using the HuggingFace datasets library
+and you want to take advantage of the multiprocessing/batch processing/caching functionalities of HF datasets.
+
+Example:
+>>> from datasets import load_dataset
+>>> from hezar.data import SpeechRecognitionDatasetProcessor
+>>> from hezar.preprocessors import Preprocessor
+
+>>> preprocessor = Preprocessor.load("hezarai/whisper-small-fa")
+>>> data_processor = SpeechRecognitionDatasetProcessor(
+...     feature_extractor=preprocessor.audio_feature_extractor,
+...     tokenizer=preprocessor.tokenizer,
+... )
+>>> dataset = load_dataset("hezarai/common-voice-13-fa")
+>>> dataset = dataset.map(data_processor, batched=True, batch_size=1000)
+"""
+
+import torch
+
+from ..constants import Backends
+from ..utils import is_backend_available, reverse_string_digits, verify_dependencies
+
+
+if is_backend_available(Backends.DATASETS):
+    from datasets.formatting.formatting import LazyBatch, LazyRow
+
+__all__ = [
+    "DatasetProcessor",
+    "ImageCaptioningDatasetProcessor",
+    "OCRDatasetProcessor",
+    "SequenceLabelingDatasetProcessor",
+    "SpeechRecognitionDatasetProcessor",
+    "TextClassificationDatasetProcessor",
+    "TextSummarizationDatasetProcessor",
+]
+
+
+class DatasetProcessor:
+    """
+    The base callable dataset processor class that can handle both single and batched mode dataset mapping.
+    """
+    required_backends = [Backends.DATASETS]
+
+    def __init__(self, *args, **kwargs):
+        verify_dependencies(self, self.required_backends)
+        self.args = args
+        self.kwargs = kwargs
+
+    def __call__(self, data: "LazyBatch | LazyRow", return_tensors="list", **kwargs):
+        """
+        Method called when using the map function.
+        Decides whether to call `process_single()` or `process_batch()` based on the data values.
+
+        Args:
+            data: A dict of feature name -> sample or batch of samples mapping.
+            return_tensors: The type of the returning tensors (list, torch, numpy)
+            **kwargs: Additional keyword arguments passed through the `map` function as `kwargs`.
+        """
+        if isinstance(data, LazyRow):
+            return self.process_single(data, return_tensors=return_tensors, **kwargs)
+        elif isinstance(data, LazyBatch):
+            return self.process_batch(data, return_tensors=return_tensors, **kwargs)
+        else:
+            raise ValueError(f"The input data must be either `LazyBatch` or `LazyRow`, got `{type(data)}`!")
+
+    def process_single(self, data: "LazyRow", return_tensors=None, **kwargs):
+        """
+        Process a single data example.
+
+        Args:
+            data: A data sample dict
+            return_tensors: The type of the returning tensors (list, torch, numpy)
+            **kwargs: Additional arguments
+
+        Returns:
+            The updated data dict
+        """
+        raise NotImplementedError
+
+    def process_batch(self, data: "LazyBatch", return_tensors=None, **kwargs):
+        """
+        Process a batch of data examples.
+
+        Args:
+            data: A batch of data samples dict
+            return_tensors: The type of the returning tensors (list, torch, numpy)
+            **kwargs: Additional arguments
+
+        Returns:
+            The updated data dict
+        """
+        raise NotImplementedError
+
+
+class ImageCaptioningDatasetProcessor(DatasetProcessor):
+    """
+    Dataset processor for image captioning datasets. This class handles tokenization and image processing.
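+
+    Args:
+        image_processor: A Hezar image processor instance.
+        tokenizer: A Hezar tokenizer instance.
+        max_length: Maximum length of the tokenized captions.
+        padding: Padding strategy for the tokenized captions, e.g., `max_length` or `longest`.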
+ """ + + def __init__(self, image_processor, tokenizer, max_length=None, padding=None): + super().__init__() + self.image_processor = image_processor + self.tokenizer = tokenizer + self.max_length = max_length + self.padding = padding + + @staticmethod + def _shift_tokens_right(input_ids: list[list[int]], pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + # Initialize shifted_input_ids as a list of lists with the same shape as input_ids + shifted_input_ids = [[0] * len(row) for row in input_ids] + + for i, row in enumerate(input_ids): + # Shift each row one token to the right + shifted_input_ids[i][1:] = row[:-1] + # Set the first token to decoder_start_token_id + shifted_input_ids[i][0] = decoder_start_token_id + + # Replace any -100 values with pad_token_id + shifted_input_ids[i] = [pad_token_id if token == -100 else token for token in shifted_input_ids[i]] + + return shifted_input_ids + + def process_single(self, data, return_tensors=None, padding=None, max_length=None): + """ + Process image and tokenize captions for a single data sample. + + Args: + data: A data example containing the image and its caption + padding: Padding type e.g, max_length, longest. + max_length: Max length value if padding is set to max_length or the labels must be truncated. + return_tensors: The type of the returning tensors (list, torch, numpy) + + Returns: + A dict of pixel values tensor of the processed image and labels token ids and attention mask. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + path = data["image_path"] + text = data["label"] + + tokenized_inputs = self.tokenizer(text, padding=padding, max_length=max_length, return_tensors=return_tensors) + + data["pixel_values"] = self.image_processor(path, return_tensors=return_tensors)["pixel_values"] + data["labels"] = tokenized_inputs["token_ids"] + data["decoder_attention_mask"] = tokenized_inputs["attention_mask"] + data["decoder_input_ids"] = self._shift_tokens_right( + [data["labels"]], + pad_token_id=self.tokenizer.pad_token_id, + decoder_start_token_id=self.tokenizer.bos_token_id, + )[0] + + return data + + def process_batch(self, data, return_tensors=None, padding=None, max_length=None): + """ + Process image and tokenize captions for a batch of data samples. + + Args: + data: A batch of data examples containing the images and their captions + padding: Padding type e.g, max_length, longest. + max_length: Max length value if padding is set to max_length or the labels must be truncated. + return_tensors: The type of the returning tensors (list, torch, numpy) + + Returns: + A dict of pixel values tensor of the processed images and labels token ids and attention masks. 
+ """ + padding = padding or self.padding + max_length = max_length or self.max_length + + paths = data["image_path"] + texts = data["label"] + + tokenized_inputs = self.tokenizer(texts, padding=padding, max_length=max_length, return_tensors=return_tensors) + + data["pixel_values"] = self.image_processor(paths, return_tensors=return_tensors)["pixel_values"] + data["labels"] = tokenized_inputs["token_ids"] + data["decoder_attention_mask"] = tokenized_inputs["attention_mask"] + data["decoder_input_ids"] = self._shift_tokens_right( + data["labels"], + pad_token_id=self.tokenizer.pad_token_id, + decoder_start_token_id=self.tokenizer.bos_token_id, + ) + + return data + + +class OCRDatasetProcessor(DatasetProcessor): + """ + Dataset processor class for OCR which can handle both tokenizer-based or character-split-based datasets. + """ + + def __init__( + self, + image_processor, + tokenizer=None, + text_split_type="char_split", + max_length=None, + reverse_digits=False, + id2label=None, + image_field="image_path", + text_field="text", + ): + super().__init__() + self.image_processor = image_processor + self.tokenizer = tokenizer + self.text_split_type = text_split_type + self.max_length = max_length + self.reverse_digits = reverse_digits + self.id2label = id2label + self.image_field = image_field + self.text_field = text_field + + def _text_to_ids(self, text): + """ + Convert text to tensor based on the configured text_split_type. + + Args: + text (str): The raw text. + + Returns: + torch.Tensor: The output tensor. + + """ + if self.text_split_type == "tokenize": + token_ids = self.tokenizer(text, padding="max_length", max_length=self.max_length)["input_ids"] + labels = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids] + elif self.text_split_type == "char_split": + if self.reverse_digits: + text = reverse_string_digits(text) + label2id = {v: k for k, v in self.id2label.items()} + labels = [label2id[char] for char in text] + else: + raise ValueError(f"Invalid `text_split_type={self.text_split_type}`") + return labels + + def process_single(self, data, return_tensors=None): + """ + Process a single image-to-text OCR example. + + Args: + data: A data example containing an image path and corresponding text. + return_tensors: The type of the returning tensors (list, torch, numpy) + + Returns: + dict: Processed inputs with pixel values and text labels. + """ + path = data[self.image_field] + text = data[self.text_field] + pixel_values = self.image_processor(path, return_tensors=return_tensors)["pixel_values"][0] + labels = self._text_to_ids(text) + return {"pixel_values": pixel_values, "labels": labels} + + def process_batch(self, data, return_tensors=None): + """ + Process a batch of image-to-text OCR examples. + + Args: + data: A batch of data examples containing image paths and corresponding texts. + return_tensors: The type of the returning tensors (list, torch, numpy) + + Returns: + dict: Batch of processed inputs with pixel values and text labels. + """ + paths = data[self.image_field] + texts = data[self.text_field] + + # Process images in batch + pixel_values = self.image_processor(paths, return_tensors=return_tensors)["pixel_values"] + + # Process text labels in batch + labels = [self._text_to_ids(text) for text in texts] + + return {"pixel_values": pixel_values, "labels": labels} + + +class SequenceLabelingDatasetProcessor(DatasetProcessor): + """ + Dataset processor class for sequence labeling datasets. Handles tokenization and label alignment. 
+ """ + + def __init__(self, tokenizer, label_all_tokens=True, ignore_index=-100, max_length=None, padding=None): + super().__init__() + self.tokenizer = tokenizer + self.label_all_tokens = label_all_tokens + self.ignore_index = ignore_index + self.max_length = max_length + self.padding = padding + + def _tokenize_and_align(self, tokens, labels, return_tensors=None, padding=None, max_length=None): + """ + Tokenize and align tokens and labels for sequence labeling tasks. + + Args: + tokens: List of tokens (for single examples) or list of lists (for batches). + labels: List of labels (for single examples) or list of lists (for batches). + return_tensors: The type of the returning tensors (list, torch, numpy) + padding: Padding strategy for tokenization. + max_length: Maximum sequence length to truncate/pad. + + Returns: + dict: Tokenized and aligned inputs with labels. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + + # Tokenize and return word IDs for mapping labels to subword tokens + tokenized_inputs = self.tokenizer( + tokens, + is_split_into_words=True, + return_word_ids=True, + padding=padding, + truncation=True, + max_length=max_length, + return_tensors=return_tensors + ) + word_ids = tokenized_inputs["word_ids"] + + # Align labels with tokens + aligned_labels = [] + for batch_idx, batch_word_ids in enumerate(word_ids): + previous_word_idx = None + label_ids = [] + for word_idx in batch_word_ids: + # Assign ignore index for special tokens + if word_idx is None: + label_ids.append(self.ignore_index) + elif word_idx != previous_word_idx: + # Assign the label for the first token of each word + label_ids.append(labels[batch_idx][word_idx]) + else: + # Assign label for subword tokens (if label_all_tokens is True) + label_ids.append(labels[batch_idx][word_idx] if self.label_all_tokens else self.ignore_index) + previous_word_idx = word_idx + aligned_labels.append(label_ids) + + tokenized_inputs["labels"] = aligned_labels + return tokenized_inputs + + def process_single(self, data, return_tensors=None, padding=None, max_length=None): + """ + Process a single example of sequence labeling data. + + Args: + data: A single data example containing tokens and labels. + return_tensors: The type of the returning tensors (list, torch, numpy) + padding: Padding strategy. + max_length: Maximum sequence length. + + Returns: + dict: Tokenized and aligned input data. + """ + tokens = data["tokens"] + labels = data["pos_tags"] + + tokenized_inputs = self._tokenize_and_align( + [tokens], + [labels], + return_tensors=return_tensors, + padding=padding, + max_length=max_length, + ) + tokenized_inputs = {k: v[0] for k, v in tokenized_inputs.items()} + data.update(tokenized_inputs) + + return data + + def process_batch(self, data, return_tensors=None, padding=None, max_length=None): + """ + Process a batch of sequence labeling examples. + + Args: + data: A batch of examples, containing tokens and labels. + return_tensors: The type of the returning tensors (list, torch, numpy) + padding: Padding strategy. + max_length: Maximum sequence length. + + Returns: + dict: Tokenized and aligned batch data. + """ + tokens = data["tokens"] + labels = data["pos_tags"] + + tokenized_inputs = self._tokenize_and_align( + tokens, + labels, + return_tensors=return_tensors, + padding=padding, + max_length=max_length, + ) + + data.update(tokenized_inputs) + + return data + + +class SpeechRecognitionDatasetProcessor(DatasetProcessor): + """ + Processor class for speech recognition datasets. 
Handles audio feature extraction and labels tokenization. + """ + + def __init__( + self, + feature_extractor, + tokenizer, + sampling_rate=16000, + audio_array_padding=None, + max_audio_array_length=None, + labels_padding=None, + labels_max_length=None, + audio_column="audio", + transcript_column="transcript", + ): + super().__init__() + self.feature_extractor = feature_extractor + self.tokenizer = tokenizer + self.sampling_rate = sampling_rate + self.audio_array_padding = audio_array_padding + self.max_audio_array_length = max_audio_array_length + self.labels_padding = labels_padding + self.labels_max_length = labels_max_length + self.audio_column = audio_column + self.transcript_column = transcript_column + + def process_single(self, data, return_tensors=None): + """ + Process a single speech recognition example. + + Args: + data: A data example containing audio and its transcript. + return_tensors: The type of the returning tensors (list, torch, numpy) + + Returns: + dict: Processed input features and labels. + """ + audio_array = data[self.audio_column]["array"] + transcript = data[self.transcript_column] + + # Extract input features from audio + input_features = self.feature_extractor( + audio_array, + sampling_rate=self.sampling_rate, + padding=self.audio_array_padding, + max_length=self.max_audio_array_length, + return_tensors=return_tensors, + )["input_features"] + + # Tokenize the transcript + labels = self.tokenizer( + transcript, + padding=self.labels_padding, + max_length=self.labels_max_length, + return_tensors=return_tensors, + ) + + data["input_features"] = input_features + data["labels"] = labels["token_ids"] + data["attention_mask"] = labels["attention_mask"] + + return data + + def process_batch(self, data, return_tensors=None): + """ + Process a batch of speech recognition examples. + + Args: + data: A batch of data examples containing audio arrays and their corresponding transcripts. + return_tensors: The type of the returning tensors (list, torch, numpy) + + Returns: + dict: Batch of processed input features and labels. + """ + audio_arrays = [x["array"] for x in data[self.audio_column]] + transcripts = data[self.transcript_column] + + # Extract input features in batch + input_features = self.feature_extractor( + audio_arrays, + sampling_rate=self.sampling_rate, + padding=self.audio_array_padding, + max_length=self.max_audio_array_length, + return_tensors=return_tensors, + )["input_features"] + + # Tokenize transcripts in batch + labels = self.tokenizer( + transcripts, + padding=self.labels_padding, + max_length=self.labels_max_length, + return_tensors=return_tensors, + ) + + data["input_features"] = input_features + data["labels"] = labels["token_ids"] + data["attention_mask"] = labels["attention_mask"] + + return data + + +class TextClassificationDatasetProcessor(DatasetProcessor): + """ + Processor class for text classification datasets. Handles tokenization of the texts. + """ + + def __init__(self, tokenizer, max_length=None, padding=None): + super().__init__() + self.tokenizer = tokenizer + self.padding = padding + self.max_length = max_length + + def process_single(self, data, return_tensors=None, padding=None, max_length=None): + """ + Process a single example for text classification. 
+
+        Args:
+            data: A single data example dict
+            return_tensors: The type of the returning tensors (list, torch, numpy)
+            padding: Token ids padding type
+            max_length: Max input length
+
+        Returns:
+            The updated data dictionary
+        """
+        padding = padding or self.padding
+        max_length = max_length or self.max_length
+
+        text = data["text"]
+        label = data["label"]
+
+        inputs = self.tokenizer(
+            text,
+            padding=padding,
+            max_length=max_length,
+            return_attention_mask=True,
+            return_tensors=return_tensors,
+        )
+        data.update(inputs)
+        data["labels"] = torch.tensor(label, dtype=torch.long)
+
+        return data
+
+    def process_batch(self, data, return_tensors=None, padding=None, max_length=None):
+        """
+        Process a batch of examples for text classification.
+
+        Args:
+            data: A batch of data examples dict
+            return_tensors: The type of the returning tensors (list, torch, numpy)
+            padding: Token ids padding type
+            max_length: Max input length
+
+        Returns:
+            The updated data dictionary
+        """
+        padding = padding or self.padding
+        max_length = max_length or self.max_length
+
+        texts = data["text"]
+        labels = data["label"]
+
+        inputs = self.tokenizer(
+            texts,
+            padding=padding,
+            max_length=max_length,
+            return_attention_mask=True,
+            return_tensors=return_tensors,
+        )
+        data.update(inputs)
+        data["labels"] = torch.tensor(labels, dtype=torch.long)
+
+        return data
+
+
+class TextSummarizationDatasetProcessor(DatasetProcessor):
+    """
+    Processor class for text summarization datasets. Handles tokenization of the inputs and labels.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        prefix=None,
+        max_length=None,
+        labels_max_length=None,
+        text_field="text",
+        summary_field="summary",
+        padding=None,
+    ):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.prefix = prefix
+        self.max_length = max_length
+        self.labels_max_length = labels_max_length
+        self.text_field = text_field
+        self.summary_field = summary_field
+        self.padding = padding
+
+    def process_single(self, data, return_tensors=None, padding=None, max_length=None, labels_max_length=None):
+        """
+        Process a single example for text summarization.
+
+        Args:
+            data: A data example containing text and summary.
+            return_tensors: The type of the returning tensors (list, torch, numpy)
+            padding: Padding strategy.
+            max_length: Max length for the input text.
+            labels_max_length: Max length for the summary labels.
+
+        Returns:
+            dict: Tokenized inputs and labels for the summarization task.
+        """
+        padding = padding or self.padding
+        max_length = max_length or self.max_length
+        labels_max_length = labels_max_length or self.labels_max_length
+
+        text = data[self.text_field]
+        summary = data[self.summary_field]
+
+        # Add prefix if needed for conditional generation
+        if self.prefix is not None:
+            text = self.prefix + text
+
+        # Tokenize inputs and labels
+        inputs = self.tokenizer(
+            text,
+            max_length=max_length,
+            padding=padding,
+            return_attention_mask=True,
+            return_tensors=return_tensors,
+        )
+        labels = self.tokenizer(
+            summary,
+            max_length=labels_max_length,
+            padding=padding,
+            return_attention_mask=True,
+            return_tensors=return_tensors,
+        )
+
+        inputs["labels"] = labels["token_ids"]
+
+        return inputs
+
+    def process_batch(self, data, return_tensors=None, padding=None, max_length=None, labels_max_length=None):
+        """
+        Process a batch of examples for text summarization.
+
+        Args:
+            data: A batch of examples containing texts and summaries.
+            return_tensors: The type of the returning tensors (list, torch, numpy)
+            padding: Padding strategy.
+ max_length: Max length for input texts. + labels_max_length: Max length for summary labels. + + Returns: + dict: Tokenized inputs and labels for summarization task. + """ + padding = padding or self.padding + max_length = max_length or self.max_length + labels_max_length = labels_max_length or self.labels_max_length + + texts = data[self.text_field] + summaries = data[self.summary_field] + + # Add prefix if needed for conditional generation + if self.prefix is not None: + texts = [self.prefix + text for text in texts] + + # Tokenize inputs and labels in batch + inputs = self.tokenizer( + texts, + max_length=max_length, + padding=padding, + return_attention_mask=True, + return_tensors=return_tensors, + ) + labels = self.tokenizer( + summaries, + max_length=labels_max_length, + padding=padding, + return_attention_mask=True, + return_tensors=return_tensors, + ) + + inputs["labels"] = labels["token_ids"] + + return inputs diff --git a/hezar/data/datasets/image_captioning_dataset.py b/hezar/data/datasets/image_captioning_dataset.py index 51e53fd4..29505828 100644 --- a/hezar/data/datasets/image_captioning_dataset.py +++ b/hezar/data/datasets/image_captioning_dataset.py @@ -77,19 +77,17 @@ def __getitem__(self, index): dict: The input data. """ path, text = self.data[index].values() - pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"] + pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"][0] tokenized_inputs = self.tokenizer(text, padding="max_length", max_length=self.config.max_length) - labels = torch.tensor([tokenized_inputs["token_ids"]]) - attention_mask = torch.tensor([tokenized_inputs["attention_mask"]]) decoder_input_ids = shift_tokens_right( - labels, + [tokenized_inputs["token_ids"]], pad_token_id=self.tokenizer.pad_token_id, decoder_start_token_id=self.tokenizer.bos_token_id, - ) + )[0] inputs = { "pixel_values": pixel_values, - "labels": labels, + "labels": tokenized_inputs["token_ids"], "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": attention_mask, + "decoder_attention_mask": tokenized_inputs["attention_mask"], } return inputs diff --git a/hezar/data/datasets/ocr_dataset.py b/hezar/data/datasets/ocr_dataset.py index e0d05110..c11fb154 100644 --- a/hezar/data/datasets/ocr_dataset.py +++ b/hezar/data/datasets/ocr_dataset.py @@ -123,7 +123,7 @@ def _load(self, split=None): data = data.select(valid_indices) return data - def _text_to_tensor(self, text): + def _text_to_ids(self, text): """ Convert text to tensor based on the configured text_split_type. 
@@ -139,14 +139,12 @@ def _text_to_tensor(self, text): token_ids = self.tokenizer(text, padding="max_length", max_length=self.config.max_length)["token_ids"] # Make sure to ignore pad tokens by the loss function token_ids = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids] - labels = torch.tensor(token_ids) # If split text is not tokenizer-based elif self.config.text_split_type == TextSplitType.CHAR_SPLIT: if self.config.reverse_digits: text = reverse_string_digits(text) label2id = {v: k for k, v in self.config.id2label.items()} labels = [label2id[x] for x in text] - labels = torch.LongTensor(labels) else: raise ValueError(f"Invalid `text_split_type={self.config.text_split_type}`") @@ -165,7 +163,7 @@ def __getitem__(self, index): """ path, text = self.data[index].values() pixel_values = self.image_processor(path, return_tensors="torch")["pixel_values"][0] - labels = self._text_to_tensor(text) + labels = self._text_to_ids(text) inputs = { "pixel_values": pixel_values, "labels": labels, diff --git a/hezar/data/datasets/speech_recognition_dataset.py b/hezar/data/datasets/speech_recognition_dataset.py index d8a5c3f2..1d85c3e1 100644 --- a/hezar/data/datasets/speech_recognition_dataset.py +++ b/hezar/data/datasets/speech_recognition_dataset.py @@ -62,14 +62,13 @@ def __getitem__(self, index): input_features = self.feature_extractor( audio_array, sampling_rate=self.config.sampling_rate, - return_tensors="torch" - )["input_features"] + return_tensors="numpy" + )["input_features"][0] labels = self.tokenizer( transcript, padding=self.config.labels_padding, max_length=self.config.labels_max_length, - return_tensors="torch", ) return { diff --git a/hezar/data/datasets/text_classification_dataset.py b/hezar/data/datasets/text_classification_dataset.py index b27ad858..6e8a5c9d 100644 --- a/hezar/data/datasets/text_classification_dataset.py +++ b/hezar/data/datasets/text_classification_dataset.py @@ -94,14 +94,7 @@ def __getitem__(self, index): """ text = self.data[index][self.config.text_field] label = self.data[index][self.config.label_field] - inputs = self.tokenizer( - text, - return_tensors="torch", - truncation=True, - padding="longest", - return_attention_mask=True, - ) - label_idx = torch.tensor([label], dtype=torch.long) # noqa - inputs["labels"] = label_idx + inputs = self.tokenizer(text, return_attention_mask=True) + inputs["labels"] = label return inputs diff --git a/hezar/data/datasets/text_summarization_dataset.py b/hezar/data/datasets/text_summarization_dataset.py index b85f1d3f..2dfd4f0f 100644 --- a/hezar/data/datasets/text_summarization_dataset.py +++ b/hezar/data/datasets/text_summarization_dataset.py @@ -26,7 +26,7 @@ class TextSummarizationDatasetConfig(DatasetConfig): summary_field (str): Field name for summary in the dataset. title_field (str): Field name for title in the dataset. max_length (int): Maximum length of text. - max_target_length (int): Maximum length of the target summary. + labels_max_length (int): Maximum length of the target summary. 
""" name = "text_summarization" @@ -37,7 +37,7 @@ class TextSummarizationDatasetConfig(DatasetConfig): summary_field: str = None title_field: str = None max_length: int = None - max_target_length: int = None + labels_max_length: int = None @register_dataset("text_summarization", config_class=TextSummarizationDatasetConfig) @@ -58,8 +58,8 @@ def __init__(self, config: TextSummarizationDatasetConfig, split=None, preproces self.data_collator = TextGenerationDataCollator( tokenizer=self.tokenizer, max_length=self.config.max_length, - max_target_length=self.config.max_target_length, - padding="max_length" if self.config.max_length else "longest", + labels_max_length=self.config.labels_max_length, + padding="max_length" if self.config.max_length else None, ) def _load(self, split): @@ -94,16 +94,14 @@ def __getitem__(self, index): inputs = self.tokenizer( text, - return_tensors="torch", max_length=self.config.max_length, - padding="max_length" if self.config.max_length else "longest", + padding="max_length" if self.config.max_length else None, return_attention_mask=True, ) labels = self.tokenizer( summary, - return_tensors="torch", max_length=self.config.max_length, - padding="max_length" if self.config.max_target_length else "longest", + padding="max_length" if self.config.labels_max_length else None, return_attention_mask=True, ) diff --git a/hezar/models/speech_recognition/whisper/whisper_tokenizer.py b/hezar/models/speech_recognition/whisper/whisper_tokenizer.py index c980e78a..9ff23e9c 100644 --- a/hezar/models/speech_recognition/whisper/whisper_tokenizer.py +++ b/hezar/models/speech_recognition/whisper/whisper_tokenizer.py @@ -254,11 +254,8 @@ @dataclass class WhisperBPEConfig(BPEConfig): name = "whisper_bpe_tokenizer" - max_length: int = 448 - truncation: str = "longest_first" truncation_side: str = "right" stride: int = 0 - padding: str = "longest" padding_side: str = "right" pad_to_multiple_of: int = 0 pad_token: str = "<|endoftext|>" diff --git a/hezar/preprocessors/image_processor.py b/hezar/preprocessors/image_processor.py index 3665966a..4d30d16a 100644 --- a/hezar/preprocessors/image_processor.py +++ b/hezar/preprocessors/image_processor.py @@ -88,7 +88,7 @@ def __init__(self, config: ImageProcessorConfig, **kwargs): def __call__( self, - images: List, + images: str | List, device: str = None, mean: float = None, std: float = None, @@ -104,7 +104,7 @@ def __call__( Perform sequential image processing on a list of input images. Args: - images (List): A list of input images of types torch, numpy, pillow. + images (str | List): A list of input images (torch, numpy, pillow) OR path or list of paths to images. mean (float): Image mean value for normalization. std (float): Image std value for normalization. rescale (float): Scale factor for rescaling the image. 
diff --git a/hezar/preprocessors/image_processor.py b/hezar/preprocessors/image_processor.py
index 3665966a..4d30d16a 100644
--- a/hezar/preprocessors/image_processor.py
+++ b/hezar/preprocessors/image_processor.py
@@ -88,7 +88,7 @@ def __init__(self, config: ImageProcessorConfig, **kwargs):

     def __call__(
         self,
-        images: List,
+        images: str | List,
         device: str = None,
         mean: float = None,
         std: float = None,
@@ -104,7 +104,7 @@ def __call__(
         Perform sequential image processing on a list of input images.

         Args:
-            images (List): A list of input images of types torch, numpy, pillow.
+            images (str | List): A list of input images (torch, numpy, pillow) OR a path or a list of paths to images.
             mean (float): Image mean value for normalization.
             std (float): Image std value for normalization.
             rescale (float): Scale factor for rescaling the image.
@@ -126,8 +126,10 @@ def __call__(
         mirror = mirror or self.config.mirror
         gray_scale = gray_scale or self.config.gray_scale

+        is_single = False
         if not isinstance(images, list) or isinstance(images, str) or isinstance(images, np.ndarray):
             images = [images]
+            is_single = True

         # Load images if inputs are list of files
         images = [load_image(x, return_type="numpy") if isinstance(x, str) else x for x in images]
@@ -162,11 +164,14 @@ def __call__(
         images = convert_batch_dict_dtype({"pixel_values": images}, dtype=return_tensors)

-        if device:
+        if device and return_tensors == "torch":
             import torch

             images = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in images.items()}

+        if is_single and return_tensors == "list":
+            images = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in images.items()}
+
         return images

     @classmethod
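The `__call__` changes above let a single image path round-trip without manual list wrapping: a bare string is boxed into a singleton batch and, when `return_tensors="list"`, squeezed back out on return. A minimal sketch (the Hub path and image file are assumptions for illustration):

```python
from hezar.preprocessors import ImageProcessor

# Hypothetical Hub path that ships an image processor config
image_processor = ImageProcessor.load("hezarai/crnn-base-fa-64x256")

# A single path is accepted directly; the singleton batch dim is squeezed on return
outputs = image_processor("path/to/image.png", return_tensors="list")
pixel_values = outputs["pixel_values"]
```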
diff --git a/hezar/preprocessors/tokenizers/bpe.py b/hezar/preprocessors/tokenizers/bpe.py
index 51802eed..3da3a7dc 100644
--- a/hezar/preprocessors/tokenizers/bpe.py
+++ b/hezar/preprocessors/tokenizers/bpe.py
@@ -19,11 +19,8 @@
 @dataclass
 class BPEConfig(TokenizerConfig):
     name = "bpe_tokenizer"
-    max_length: int = 512
-    truncation: str = "longest_first"
     truncation_side: str = "right"
     stride: int = 0
-    padding: str = "longest"
     padding_side: str = "right"
     pad_to_multiple_of: int = 0
     bos_token: str = "<s>"
diff --git a/hezar/preprocessors/tokenizers/sentencepiece_bpe.py b/hezar/preprocessors/tokenizers/sentencepiece_bpe.py
index ecb0df1e..36cda5b8 100644
--- a/hezar/preprocessors/tokenizers/sentencepiece_bpe.py
+++ b/hezar/preprocessors/tokenizers/sentencepiece_bpe.py
@@ -19,11 +19,8 @@
 @dataclass
 class SentencePieceBPEConfig(TokenizerConfig):
     name = "sentencepiece_bpe_tokenizer"
-    max_length: int = 512
-    truncation: str = "longest_first"
     truncation_side: str = "right"
     stride: int = 0
-    padding: str = "longest"
     padding_side: str = "right"
     bos_token: str = "<s>"
     eos_token: str = "</s>"
diff --git a/hezar/preprocessors/tokenizers/sentencepiece_unigram.py b/hezar/preprocessors/tokenizers/sentencepiece_unigram.py
index c38d352d..834b6dc0 100644
--- a/hezar/preprocessors/tokenizers/sentencepiece_unigram.py
+++ b/hezar/preprocessors/tokenizers/sentencepiece_unigram.py
@@ -19,11 +19,8 @@
 @dataclass
 class SentencePieceUnigramConfig(TokenizerConfig):
     name = "sentencepiece_unigram_tokenizer"
-    max_length: int = 512
-    truncation: str = "longest_first"
     truncation_side: str = "right"
     stride: int = 0
-    padding: str = "longest"
     padding_side: str = "right"
     bos_token: str = "<s>"
     eos_token: str = "</s>"
diff --git a/hezar/preprocessors/tokenizers/tokenizer.py b/hezar/preprocessors/tokenizers/tokenizer.py
index 732fccd1..33a83e35 100644
--- a/hezar/preprocessors/tokenizers/tokenizer.py
+++ b/hezar/preprocessors/tokenizers/tokenizer.py
@@ -42,11 +42,8 @@ class TokenizerConfig(PreprocessorConfig):
     Configuration for the Tokenizer.

     Args:
-        max_length (int): Maximum length of the tokenized sequences.
-        truncation (str): Truncation strategy for tokenization.
         truncation_side (str): Truncation direction for tokenization.
         stride (int): Stride for tokenization.
-        padding (str): Padding type for tokenization e.g, max_length, longest, no_padding.
         padding_side (str): Padding direction for tokenization.
         pad_to_multiple_of (int): Pad to a multiple of this value.
         pad_token_type_id (int): ID of the padding token type.
@@ -61,13 +58,13 @@ class TokenizerConfig(PreprocessorConfig):
     """

     name = "tokenizer"
-    max_length: int = None
-    truncation: str = None
+    max_length: int = "deprecated"
+    truncation: str = "deprecated"
     truncation_side: str = None
-    padding: str = None
+    padding: str = "deprecated"
     padding_side: str = None
     stride: int = None
-    pad_to_multiple_of: int = None
+    pad_to_multiple_of: int = "deprecated"
     pad_token_type_id: int = 0
     bos_token: str = None
     eos_token: str = None
@@ -78,6 +75,21 @@ class TokenizerConfig(PreprocessorConfig):
     mask_token: str = None
     additional_special_tokens: List[str] = None

+    def __post_init__(self):
+        super().__post_init__()
+        if self.max_length != "deprecated":
+            logger.warning(
+                "Setting `max_length` in the tokenizer config is deprecated and will be removed in the future!"
+            )
+        if self.padding != "deprecated":
+            logger.warning(
+                "Setting `padding` in the tokenizer config is deprecated and will be removed in the future!"
+            )
+        if self.truncation != "deprecated":
+            logger.warning(
+                "Setting `truncation` in the tokenizer config is deprecated and will be removed in the future!"
+            )
+

 class Tokenizer(Preprocessor):
     """
@@ -304,6 +316,8 @@ def __call__(
             " This warning will change to an error in the future!"
         )

+        return_tensors = return_tensors or "list"
+
         # Convert to batch if input is a single string or a list of words (is split into words for sequence labeling)
         if isinstance(inputs, str) or (is_split_into_words and not isinstance(inputs[0], list)):
             inputs = [inputs]
@@ -311,12 +325,6 @@ def __call__(
         else:
             is_batch = True

-        if padding is None and max_length is not None:
-            padding = PaddingType.MAX_LENGTH
-        truncation = truncation or self.config.truncation
-        max_length = max_length or self.config.max_length
-        pad_to_multiple_of = pad_to_multiple_of or self.config.pad_to_multiple_of
-
         self.set_truncation_and_padding(
             padding=padding,
             truncation=truncation,
@@ -359,6 +367,7 @@ def __call__(
                 overflow_to_sample_mapping += [i] * len(encodings_["input_ids"])
             sanitized_outputs["overflow_to_sample_mapping"] = overflow_to_sample_mapping

+        # Squeeze the outputs if the original input was a single string and `return_tensors` is "list"
         if (return_tensors == "list" or return_tensors is None) and not is_batch:
             sanitized_outputs = {
                 key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
@@ -366,6 +375,7 @@ def __call__(
             }

         outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors, skip_keys=self.uncastable_keys)
+
         if device and return_tensors == "torch":
             outputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}
@@ -382,7 +392,7 @@ def set_truncation_and_padding(
         pad_to_multiple_of: int = None,
     ):
         # Set truncation and padding on the backend tokenizer
-        if truncation == "no_truncation" or truncation is None:
+        if truncation == "no_truncation" or truncation is None or max_length is None:
             if self.truncation is not None:
                 self.no_truncation()
             else:
diff --git a/hezar/preprocessors/tokenizers/wordpiece.py b/hezar/preprocessors/tokenizers/wordpiece.py
index b97d30ef..60644fff 100644
--- a/hezar/preprocessors/tokenizers/wordpiece.py
+++ b/hezar/preprocessors/tokenizers/wordpiece.py
@@ -19,11 +19,8 @@
 @dataclass
 class WordPieceConfig(TokenizerConfig):
     name = "wordpiece_tokenizer"
-    max_length: int = 512
-    truncation: str = "longest_first"
     truncation_side: str = "right"
     stride: int = 0
-    padding: str = "longest"
     padding_side: str = "right"
     pad_to_multiple_of: int = 0
     pad_token: str = "[PAD]"
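Taken together with the `TokenizerConfig` deprecations, every tokenizer config above now defers `max_length`, `truncation`, and `padding` to the call site, and `return_tensors` defaults to `"list"`. A minimal sketch of the resulting call pattern (the Hub path is an assumption for illustration):

```python
from hezar.preprocessors import Tokenizer

tokenizer = Tokenizer.load("hezarai/bert-base-fa")

outputs = tokenizer(
    ["hello", "a longer second sentence"],
    padding="max_length",    # passed per call; no config fallback anymore
    max_length=32,
    return_tensors="torch",  # the default is now "list"
    device="cpu",            # device move only applies when return_tensors="torch"
)
```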
diff --git a/hezar/utils/data_utils.py b/hezar/utils/data_utils.py
index 2c714435..bf04001a 100644
--- a/hezar/utils/data_utils.py
+++ b/hezar/utils/data_utils.py
@@ -1,17 +1,16 @@
 from __future__ import annotations

 from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional

+import numpy as np
+import torch
 from omegaconf import DictConfig

 from ..constants import PaddingType
 from .logging import Logger

-if TYPE_CHECKING:
-    import torch
-

 __all__ = [
     "convert_batch_dict_dtype",
     "resolve_inputs_length_for_padding",
@@ -166,18 +165,50 @@ def pad_batch_items(
     return padded_inputs


-def shift_tokens_right(input_ids: "torch.Tensor", pad_token_id: int, decoder_start_token_id: int):
+def shift_tokens_right(
+    token_ids: list[list[int]] | "torch.Tensor" | "np.ndarray",
+    pad_token_id: int,
+    decoder_start_token_id: int
+):
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    # Check if input is a list of lists
+    if isinstance(token_ids, list):
+        # Initialize shifted_input_ids with the same shape as token_ids
+        shifted_input_ids = [[0] * len(row) for row in token_ids]
+        for i, row in enumerate(token_ids):
+            # Shift each row one token to the right
+            shifted_input_ids[i][1:] = row[:-1]
+            # Set the first token to decoder_start_token_id
+            shifted_input_ids[i][0] = decoder_start_token_id
+            # Replace any -100 values with pad_token_id
+            shifted_input_ids[i] = [pad_token_id if token == -100 else token for token in shifted_input_ids[i]]
+        return shifted_input_ids
+
+    # Check if input is a NumPy array
+    elif isinstance(token_ids, np.ndarray):
+        # Initialize shifted_input_ids with zeros and the same shape as token_ids
+        shifted_input_ids = np.zeros_like(token_ids)
+        shifted_input_ids[:, 1:] = token_ids[:, :-1]
+        shifted_input_ids[:, 0] = decoder_start_token_id
+        # Replace any -100 values with pad_token_id
+        shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+        return shifted_input_ids
+
+    # Check if input is a PyTorch tensor
+    elif isinstance(token_ids, torch.Tensor):
+        # Initialize shifted_input_ids with zeros and the same shape as token_ids
+        shifted_input_ids = token_ids.new_zeros(token_ids.shape)
+        shifted_input_ids[:, 1:] = token_ids[:, :-1].clone()
+        shifted_input_ids[:, 0] = decoder_start_token_id
+        # Replace any -100 values with pad_token_id
+        shifted_input_ids = shifted_input_ids.masked_fill(shifted_input_ids == -100, pad_token_id)
+        return shifted_input_ids

-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
-
-    return shifted_input_ids
+    else:
+        raise TypeError("Unsupported input type. Expected list, numpy array, or torch tensor.")


 def torch2numpy(*args):
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 6a003cbe..9a2ae80c 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -60,7 +60,7 @@
         "config": {
             "max_size": 4,
             "max_length": 128,
-            "max_target_length": 32,
+            "labels_max_length": 32,
         }
     },
     "model": {