Skip to content

Commit

Permalink
Whisper model support in Lite (#11464)
Browse files Browse the repository at this point in the history
* Data loader

* another dataset

* preprocessed audio dataset

Signed-off-by: Onur Yilmaz <[email protected]>

* seq2seq support

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* remove any update

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* fixing validation errors

Signed-off-by: Onur Yilmaz <[email protected]>

* Modify training step and tokenizer to achieve correct Whisper training

Signed-off-by: Piotr Żelasko <[email protected]>

* Apply isort and black reformatting

Signed-off-by: pzelasko <[email protected]>

* Moved files into speechlm collection

Signed-off-by: Onur Yilmaz <[email protected]>

* revert changes

Signed-off-by: Onur Yilmaz <[email protected]>

* create recipes folder

Signed-off-by: Onur Yilmaz <[email protected]>

* generalize forward

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* example update

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* address codeql reviews

Signed-off-by: Onur Yilmaz <[email protected]>

* remove examples

Signed-off-by: Onur Yilmaz <[email protected]>

---------

Signed-off-by: Onur Yilmaz <[email protected]>
Signed-off-by: oyilmaz-nvidia <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
Signed-off-by: pzelasko <[email protected]>
Co-authored-by: oyilmaz-nvidia <[email protected]>
Co-authored-by: Piotr Żelasko <[email protected]>
Co-authored-by: pzelasko <[email protected]>
  • Loading branch information
4 people authored Dec 24, 2024
1 parent 8aa4b60 commit 6224655
Show file tree
Hide file tree
Showing 9 changed files with 826 additions and 0 deletions.
129 changes: 129 additions & 0 deletions examples/speechlm/sft/hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fiddle as fdl
import torch
from lhotse.dataset.collation import collate_matrices, collate_vectors
from omegaconf import OmegaConf

from nemo import lightning as nl
from nemo.collections import speechlm
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.speechlm.models import HFAutoModelForSpeechSeq2Seq

torch.set_float32_matmul_precision("medium")


class LhotseHfNeMoDataset(torch.utils.data.Dataset):
def __init__(self, processor, tokenizer, decoder_mask_fill=-100):
super().__init__()
self.processor = processor
self.tokenizer = tokenizer
self.decoder_mask_fill = decoder_mask_fill

def __getitem__(self, cuts):
features = []
for cut in cuts:
audio = cut.load_audio()
features.append(
self.processor(
audio,
sampling_rate=cut.sampling_rate,
return_tensors="pt",
text=cut.supervisions[0].text,
)
)

input_features = collate_matrices(tensors=[f["input_features"].squeeze(0) for f in features])
labels = collate_vectors(tensors=[c.supervisions[0].tokens for c in cuts])
decoder_input_ids = labels[:, :-1]
decoder_input_ids = decoder_input_ids.masked_fill(
decoder_input_ids == self.decoder_mask_fill, self.tokenizer.pad_id
)
labels = labels[:, 1:].reshape(-1)

return {
"input_features": input_features,
"labels": labels,
"decoder_input_ids": decoder_input_ids,
}


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser()

# Models can be one of the supported ones by AutoModelForSpeechSeq2Seq such as
# openai/whisper-large-v3 and facebook/s2t-small-librispeech-asr
parser.add_argument('--model', default='openai/whisper-large-v3')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
parser.add_argument('--devices', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument('--model-save-path', type=str, default=None)
args = parser.parse_args()

model = HFAutoModelForSpeechSeq2Seq(model_name=args.model)
model = model.to(torch.float)
processor = model.processor
tokenizer = AutoTokenizer(args.model, include_special_tokens=True)

config = OmegaConf.create(
{
"cuts_path": "/opt/checkpoints/lhotse/libri/libri-train-5.jsonl.gz",
"sample_rate": 16000,
"shuffle": True,
"num_workers": 2,
"batch_size": 4,
"shuffle_buffer_size": 100,
}
)

train_dataloader = get_lhotse_dataloader_from_config(
config,
global_rank=0,
world_size=1,
dataset=LhotseHfNeMoDataset(
processor=processor,
tokenizer=tokenizer,
),
tokenizer=tokenizer,
)

speechlm.api.finetune(
model=model,
data=train_dataloader,
trainer=nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
accelerator=args.accelerator,
strategy=args.strategy,
precision="bf16-mixed",
log_every_n_steps=1,
limit_val_batches=0.0,
num_sanity_val_steps=0,
accumulate_grad_batches=10,
gradient_clip_val=0.5,
use_distributed_sampler=False,
callbacks=[],
logger=None,
),
optim=fdl.build(speechlm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
log=None,
)

if args.model_save_path is not None:
model.save_pretrained(args.model_save_path)
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
additional_special_tokens: Optional[List] = [],
use_fast: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
include_special_tokens: bool = False,
):
"""
Args:
Expand All @@ -63,6 +64,7 @@ def __init__(
unk_token: token to use for unknown tokens
additional_special_tokens: list of other tokens beside standard special tokens (bos, eos, pad, etc.). For example, sentinel tokens for T5 (<extra_id_0>, <extra_id_1>, etc.)
use_fast: whether to use fast HuggingFace tokenizer
include_special_tokens: when True, converting text to ids will include special tokens / prompt tokens (if any), yielding self.tokenizer(text).input_ids
"""
try:
# this logic deals with different huggingface tokenizers having different positional args
Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(
f'Unable to instantiate HuggingFace AUTOTOKENIZER for {pretrained_model_name}. Exception: {e}'
)

self.include_special_tokens = include_special_tokens
self.original_vocab_size = len(self.tokenizer)
special_tokens_dict = {}

Expand Down Expand Up @@ -220,6 +223,8 @@ def ids_to_tokens(self, ids):
return tokens

def text_to_ids(self, text):
if self.include_special_tokens:
return self.tokenizer(text).input_ids
tokens = self.text_to_tokens(text)
ids = self.tokens_to_ids(tokens)
return ids
Expand Down
38 changes: 38 additions & 0 deletions nemo/collections/speechlm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.speechlm.models import HFAutoModelForSpeechSeq2Seq
from nemo.utils import logging

__all__ = [
"HFAutoModelForSpeechSeq2Seq",
]

try:
import nemo_run as run

from nemo.collections.llm.recipes import adam
from nemo.collections.speechlm.api import finetune, generate, pretrain, train, validate

__all__.extend(
[
"train",
"pretrain",
"validate",
"finetune",
"generate",
]
)
except ImportError as error:
logging.warning(f"Failed to import nemo.collections.speechlm.[api, recipes]: {error}")
Loading

0 comments on commit 6224655

Please sign in to comment.