✨ Add dataset processors #184

Merged · 38 commits · Nov 14, 2024

Commits
1b9969d
:sparkles: Add `dataset_processors` and `DatasetProcessor` base class
arxyzan Oct 18, 2024
7980c1e
:sparkles: Add `TextClassificationDatasetProcessor` to dataset proces…
arxyzan Oct 18, 2024
f29e9d2
:sparkles: Add `TextClassificationDatasetProcessor` to dataset proces…
arxyzan Oct 18, 2024
b86a2e0
Merge branch 'refs/heads/main' into datasets-map-processing
arxyzan Oct 19, 2024
0bef124
:sparkles: Improve dataset processor logic for handling batched mode
arxyzan Oct 19, 2024
687301a
:sparkles: Add sequence labeling dataset processor
arxyzan Oct 19, 2024
d41d9f1
:sparkles: Add sequence labeling dataset processor
arxyzan Oct 19, 2024
5664078
Merge remote-tracking branch 'origin/datasets-map-processing' into da…
arxyzan Oct 19, 2024
64e511d
:pencil2: Rename `max_target_length` -> `labels_max_length` for text …
arxyzan Oct 19, 2024
8f1f0b2
:sparkles: Add text summarization dataset processor
arxyzan Oct 19, 2024
9c6c0a4
:sparkles: Add OCR dataset processor
arxyzan Oct 19, 2024
18e67e8
:sparkles: Add speech recognition dataset processor
arxyzan Oct 19, 2024
bb412f4
:fire: Move all dataset processors to one file (`dataset_processors.py`)
arxyzan Oct 19, 2024
2318680
:sparkles: Add `dataset_processing_example.py`
arxyzan Oct 19, 2024
dae7d31
Merge branch 'refs/heads/main' into datasets-map-processing
arxyzan Oct 22, 2024
705f197
:sparkles: Add `return_tensors` option to dataset processors
arxyzan Oct 22, 2024
fc5c52d
:pencil2: Temp dataset processing example
arxyzan Oct 22, 2024
95ce482
Merge branch 'refs/heads/main' into datasets-map-processing
arxyzan Oct 25, 2024
79687c9
:bug: Fix speech recognition data collator bug
arxyzan Oct 25, 2024
b9a2efb
:bug: Fix OCR data collator bug
arxyzan Oct 25, 2024
7265263
:sparkles: Return non-batched output when input is not batched in `To…
arxyzan Oct 25, 2024
e04da6e
:pencil2: Minor renamings
arxyzan Oct 25, 2024
95b566c
:sparkles: Return list objects in datasets when tokenizing by default
arxyzan Oct 25, 2024
57870ed
:bug: Handle batched inputs better in Tokenizer call
arxyzan Oct 31, 2024
f37a80e
:fire: Remove unnecessary unbatching in data collators
arxyzan Oct 31, 2024
2c32b67
:pencil2: Minor changes in dataset processors
arxyzan Oct 31, 2024
4cdf79d
:fire: Deprecate setting `max_length`, `padding`, `truncation` in the…
arxyzan Oct 31, 2024
6e2e3fa
:adhesive_bandage: Make batch conversions cleaner in data collators
arxyzan Oct 31, 2024
a1bd4c2
:adhesive_bandage: Fix truncation issue when `max_length` is None in …
arxyzan Oct 31, 2024
6288c07
:adhesive_bandage: Fix some issues in dataset processors
arxyzan Oct 31, 2024
1e54f81
:adhesive_bandage: Return `decoder_attention_mask` instead of `attent…
arxyzan Oct 31, 2024
d1515a6
:pencil2: Fix issues in data collators
arxyzan Nov 7, 2024
f6ae426
:sparkles: Make `shift_tokens_right` compatible with all tensor types
arxyzan Nov 7, 2024
e14d106
:pencil2: Return unbatched pixel values in `ImageCaptioningDataset`
arxyzan Nov 7, 2024
2b147ff
:memo: Add dataset processor guides in the docs
arxyzan Nov 8, 2024
aa12363
:bug: Fix some data processor issues
arxyzan Nov 14, 2024
07ab703
:fire: Remove deprecated tokenizer config args
arxyzan Nov 14, 2024
4866756
Merge branch 'refs/heads/main' into datasets-map-processing
arxyzan Nov 14, 2024
186 changes: 186 additions & 0 deletions docs/guide/dataset_processors.md
@@ -0,0 +1,186 @@
# Dataset Processors
Initially, Hezar's Trainer worked only with PyTorch datasets (derived from `torch.utils.data.Dataset`), like all the Hezar dataset classes
at `hezar.data.datasets`. Later on, we also added support for any iterable as the dataset in Hezar's Trainer.

One particularly important type of dataset is 🤗 Datasets. The Trainer has supported this type of dataset almost since day one,
but implementing the data pipelines was left to the user. That's why Hezar (`>=v0.42.0`) added a new category of classes
called Dataset Processors. These classes are used as dataset map callables, which brings the following benefits:
- The same processing pipeline found in the corresponding `hezar.data.Dataset` subclass is implemented as a map function.
For example, `SpeechRecognitionDatasetProcessor` corresponds to `SpeechRecognitionDataset`.
- Features like caching, multiprocessing, batch processing, etc. are available since the objects are of type `datasets.Dataset`.
- Dataset processing pipelines from other codebases become plug-and-play with Hezar's `Trainer`.

Now let's see an example demonstrating both approaches:

**Classic 🤗Datasets**

Here we need to implement a map function that processes our samples. The 🤗 Datasets `map` method accepts callables that
operate on either single or batched inputs. Below is an implementation for batched processing:
```python
from datasets import load_dataset, Audio
from hezar.preprocessors import Preprocessor


preprocessor = Preprocessor.load("hezarai/whisper-small-fa")
feature_extractor = preprocessor.audio_feature_extractor
tokenizer = preprocessor.tokenizer

def batch_process_fn(data):
    # Extract audio arrays and transcripts (the transcripts live in the "sentence" column)
    audio_arrays = [audio["array"] for audio in data["audio"]]
    transcripts = data["sentence"]

    # Extract input features in batch
    input_features = feature_extractor(
        audio_arrays,
        sampling_rate=16000,
        return_tensors="np",  # Return as numpy for compatibility with map
    )["input_features"]

    # Tokenize transcripts in batch
    labels = tokenizer(
        transcripts,
        padding="max_length",
        max_length=448,
        return_tensors="np",
    )

    # Add processed data to the dictionary
    data["input_features"] = input_features
    data["labels"] = labels["input_ids"]
    data["attention_mask"] = labels["attention_mask"]

    return data

dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.select_columns(["sentence", "audio"])
# Apply the function to the dataset using map
processed_dataset = dataset.map(batch_process_fn, batched=True)
processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
print(processed_dataset[0])
```

**Hezar Dataset Processors**

Here's equivalent code using `SpeechRecognitionDatasetProcessor`, which implements the same map function as a
callable (`SpeechRecognitionDatasetProcessor.__call__()`) that works with both single and batched inputs out of the box!
```python
from datasets import load_dataset, Audio

from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator
from hezar.preprocessors import Preprocessor

dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.select_columns(["sentence", "audio"])

preprocessor = Preprocessor.load("hezarai/whisper-small-fa")

dataset_processor = SpeechRecognitionDatasetProcessor(
    tokenizer=preprocessor.tokenizer,
    feature_extractor=preprocessor.audio_feature_extractor,
    transcript_column="sentence",
    audio_array_padding="max_length",
)
data_collator = SpeechRecognitionDataCollator(
    feature_extractor=preprocessor.audio_feature_extractor,
    tokenizer=preprocessor.tokenizer,
    labels_padding="max_length",
    labels_max_length=256,
)
processed_dataset = dataset.map(
    dataset_processor,
    batched=True,
    batch_size=100,
    desc="Processing dataset...",
)
processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
print(processed_dataset[0])
```
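
Since the mapped object is a regular `datasets.Dataset`, the usual `map` features such as multiprocessing, caching, and batch-size control apply directly. As a quick sketch (the `num_proc` value here is just illustrative):

```python
# Process with multiple worker processes; results are cached by 🤗 Datasets
# and reused on subsequent runs.
processed_dataset = dataset.map(
    dataset_processor,
    batched=True,
    batch_size=100,
    num_proc=4,
    desc="Processing dataset...",
)
```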

## How Dataset Processors Work
Dataset processor classes are callables that receive dataset rows or batches and process them when used as a map function
with `datasets.Dataset.map()`. Here are the currently supported dataset processors:
- `ImageCaptioningDatasetProcessor`
- `OCRDatasetProcessor`
- `SequenceLabelingDatasetProcessor`
- `SpeechRecognitionDatasetProcessor`
- `TextClassificationDatasetProcessor`
- `TextSummarizationDatasetProcessor`

All the above classes inherit from the base `DatasetProcessor` class and must implement the following two methods:
- `process_single(data, **kwargs)`
- `process_batch(data, **kwargs)`

The main `__call__()` method is implemented in the base class; it figures out whether the input `data` is a single row or a batch and dispatches to the appropriate method.
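
To make the pattern concrete, here is a minimal sketch of a custom processor following the same interface. This is an illustration only: the constructor arguments, column names, and the exact `DatasetProcessor` import path and base-class behavior are assumptions for the example, not the actual Hezar implementation.

```python
from hezar.data import DatasetProcessor  # assumed import path for the base class


class MyTextClassificationProcessor(DatasetProcessor):
    """Hypothetical processor that tokenizes a text column and exposes the labels."""

    def __init__(self, tokenizer, text_column="text", label_column="label", max_length=128, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.text_column = text_column
        self.label_column = label_column
        self.max_length = max_length

    def process_single(self, data, **kwargs):
        # `data` is a single dataset row (a dict of scalar values)
        inputs = self.tokenizer(
            data[self.text_column],
            padding="max_length",
            max_length=self.max_length,
        )
        data.update(inputs)
        data["labels"] = data[self.label_column]
        return data

    def process_batch(self, data, **kwargs):
        # `data` is a batch (a dict of lists), as passed by `datasets.Dataset.map(..., batched=True)`
        inputs = self.tokenizer(
            data[self.text_column],
            padding="max_length",
            max_length=self.max_length,
        )
        data.update(inputs)
        data["labels"] = data[self.label_column]
        return data
```

Because the dispatching lives in the base class, the same processor object can be passed to `datasets.Dataset.map()` with or without `batched=True`.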


## A Training Example
Let's see how we can use a dataset processor to load and process a Hub dataset for speech recognition and train a Whisper model.

```python
from datasets import load_dataset, Audio

from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator
from hezar.preprocessors import Preprocessor
from hezar.trainer import Trainer, TrainerConfig
from hezar.models import Model

dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.select_columns(["sentence", "audio"])

base_model_path = "hezarai/whisper-small"
preprocessor = Preprocessor.load(base_model_path)

dataset_processor = SpeechRecognitionDatasetProcessor(
    tokenizer=preprocessor.tokenizer,
    feature_extractor=preprocessor.audio_feature_extractor,
    transcript_column="sentence",
    audio_array_padding="max_length",
)
# This is the same data collator used in `SpeechRecognitionDataset`
data_collator = SpeechRecognitionDataCollator(
    feature_extractor=preprocessor.audio_feature_extractor,
    tokenizer=preprocessor.tokenizer,
    labels_padding="max_length",
    labels_max_length=256,
)
processed_dataset = dataset.map(
    dataset_processor,
    batched=True,
    batch_size=100,
    desc="Processing dataset...",
)
# Select needed columns for training
processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
# Split dataset for train/evaluation
processed_dataset = processed_dataset.train_test_split(test_size=0.1)

model = Model.load(base_model_path)

train_config = TrainerConfig(
    output_dir="whisper-small-fa-commonvoice",
    task="speech_recognition",
    init_weights_from=base_model_path,
    mixed_precision="bf16",
    gradient_accumulation_steps=8,
    batch_size=4,
    num_epochs=5,
    metrics=["cer", "wer"],
)

trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["test"],
    data_collator=data_collator,
)
trainer.train()
```

## Wrap-up
Dataset processors are simple yet powerful callable classes for processing datasets with the `.map()` function
in 🤗 Datasets. This integration means that all 🤗 Datasets features are unlocked when working with Hezar!
1 change: 1 addition & 0 deletions docs/guide/index.md
@@ -7,6 +7,7 @@ Welcome to the developer guide section where you can take a deeper dive into the

hezar_architecture.md
models_advanced.md
dataset_processors.md
trainer_in_depth.md
advanced_training.md
```
91 changes: 90 additions & 1 deletion docs/tutorial/datasets.md
@@ -118,7 +118,96 @@ class ImageCaptioningDataset(Dataset):
pass
```

## Loading Regular HF Datasets
## Loading with 🤗Datasets
You can load all Hezar datasets using the 🤗Datasets library too. Doing so has the following pros and cons:

**Pros**:
- You can work with any dataset on the Hub and use it easily with Hezar.
- You can leverage the multiprocessing and batch processing features of such datasets (which are not available with torch datasets).
- You can leverage mapped dataset caching provided by 🤗Datasets.
- No integration work is needed to make your existing data pipeline code work with Hezar.

**Cons**:
- You have to take care of the data processing yourself unless one of the dataset processors at `hezar.data.dataset_processors` suits your needs.

### Using Hezar's Dataset Processors
To replicate the behavior of the `hezar.data.Dataset` classes for datasets loaded with 🤗 Datasets, Hezar also implements
a group of dataset processor classes that you can use to map the loaded datasets and get the same processed instances
when iterating over them.

Below is a comparison of the two methods: Hezar's torch-compatible datasets vs. 🤗 Datasets:

**Loading and Processing with Hezar**

```python
from torch.utils.data import DataLoader
from hezar.data import SpeechRecognitionDataset, SpeechRecognitionDatasetConfig

# You can also use the shortcut `Dataset.load("hezarai/common-voice-13-fa")`; the explicit config below is shown for clarity.
dataset = SpeechRecognitionDataset(
    SpeechRecognitionDatasetConfig(
        path="hezarai/common-voice-13-fa",
        sampling_rate=16000,
        audio_file_path_column="path",
        audio_column="audio",
        audio_array_column="array",
        transcript_column="sentence",
    ),
    split="train",
    preprocessor="hezarai/whisper-small-fa",
)

loader = DataLoader(dataset, batch_size=16, collate_fn=dataset.data_collator)
itr = iter(loader)
print(next(itr))
```

**Loading and Processing with 🤗Datasets and Hezar Dataset Processors**

```python
from datasets import load_dataset, Audio
from torch.utils.data import DataLoader

from hezar.data import SpeechRecognitionDatasetProcessor, SpeechRecognitionDataCollator
from hezar.preprocessors import Preprocessor

dataset = load_dataset("hezarai/common-voice-13-fa", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.select_columns(["sentence", "audio"])
preprocessor = Preprocessor.load("hezarai/whisper-small-fa")

dataset_processor = SpeechRecognitionDatasetProcessor(
    tokenizer=preprocessor.tokenizer,
    feature_extractor=preprocessor.audio_feature_extractor,
    transcript_column="sentence",
    audio_array_padding="max_length",
)
data_collator = SpeechRecognitionDataCollator(
    feature_extractor=preprocessor.audio_feature_extractor,
    tokenizer=preprocessor.tokenizer,
    labels_padding="max_length",
    labels_max_length=256,
)
processed_dataset = dataset.map(
    dataset_processor,
    batched=True,
    batch_size=100,
    desc="Processing dataset...",
)
processed_dataset = processed_dataset.select_columns(["input_features", "labels", "attention_mask"])
data_loader = DataLoader(processed_dataset, batch_size=16, collate_fn=data_collator)
x = next(iter(data_loader))
print(x)
```
Both code snippets above give you the same kind of results. Although using dataset processors is more verbose, it gives
you more control and better integration with the data pipelines that are typically used nowadays.

```{note}
You don't necessarily need to use the dataset processor classes in Hezar. They are just there to implement the same
procedures and reproduce the same results. This means that any code that uses 🤗 Datasets will work with Hezar's Trainer.
```
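
For instance, continuing the snippet above, the processed dataset can be fed straight to the `Trainer`. The following is a sketch; the model path and config values are illustrative and mirror the speech recognition setup from the guide:

```python
from hezar.models import Model
from hezar.trainer import Trainer, TrainerConfig

# Split the mapped dataset for training/evaluation
processed_dataset = processed_dataset.train_test_split(test_size=0.1)

model = Model.load("hezarai/whisper-small")
train_config = TrainerConfig(
    output_dir="whisper-small-fa-commonvoice",
    task="speech_recognition",
    batch_size=4,
    num_epochs=5,
    metrics=["cer", "wer"],
)
trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["test"],
    data_collator=data_collator,
)
trainer.train()
```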

### Loading Regular HF Datasets
All the datasets currently provided under Hezar's Hugging Face organization have a `dataset_config.yaml` file in their repos, which
regular HF datasets do not. If you need to load such a regular dataset (one that has the correct structure and fields) in Hezar
using the `Dataset.load()` method, you have to provide the dataset config manually.