diff --git a/examples/speechlm/sft/hf.py b/examples/speechlm/sft/hf.py
new file mode 100755
index 000000000000..96e785dac97f
--- /dev/null
+++ b/examples/speechlm/sft/hf.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fiddle as fdl
+import torch
+from lhotse.dataset.collation import collate_matrices, collate_vectors
+from omegaconf import OmegaConf
+
+from nemo import lightning as nl
+from nemo.collections import speechlm
+from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.speechlm.models import HFAutoModelForSpeechSeq2Seq
+
+torch.set_float32_matmul_precision("medium")
+
+
+class LhotseHfNeMoDataset(torch.utils.data.Dataset):
+    def __init__(self, processor, tokenizer, decoder_mask_fill=-100):
+        super().__init__()
+        self.processor = processor
+        self.tokenizer = tokenizer
+        self.decoder_mask_fill = decoder_mask_fill
+
+    def __getitem__(self, cuts):
+        features = []
+        for cut in cuts:
+            audio = cut.load_audio()
+            features.append(
+                self.processor(
+                    audio,
+                    sampling_rate=cut.sampling_rate,
+                    return_tensors="pt",
+                    text=cut.supervisions[0].text,
+                )
+            )
+
+        input_features = collate_matrices(tensors=[f["input_features"].squeeze(0) for f in features])
+        labels = collate_vectors(tensors=[c.supervisions[0].tokens for c in cuts])
+        decoder_input_ids = labels[:, :-1]
+        decoder_input_ids = decoder_input_ids.masked_fill(
+            decoder_input_ids == self.decoder_mask_fill, self.tokenizer.pad_id
+        )
+        labels = labels[:, 1:].reshape(-1)
+
+        return {
+            "input_features": input_features,
+            "labels": labels,
+            "decoder_input_ids": decoder_input_ids,
+        }
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser()
+
+    # The model can be any checkpoint supported by AutoModelForSpeechSeq2Seq, such as
+    # openai/whisper-large-v3 or facebook/s2t-small-librispeech-asr
+    parser.add_argument('--model', default='openai/whisper-large-v3')
+    parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
+    parser.add_argument('--devices', default=1)
+    parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
+    parser.add_argument('--max-steps', type=int, default=100)
+    parser.add_argument('--model-save-path', type=str, default=None)
+    args = parser.parse_args()
+
+    model = HFAutoModelForSpeechSeq2Seq(model_name=args.model)
+    model = model.to(torch.float)
+    processor = model.processor
+    tokenizer = AutoTokenizer(args.model, include_special_tokens=True)
+
+    config = OmegaConf.create(
+        {
+            "cuts_path": "/opt/checkpoints/lhotse/libri/libri-train-5.jsonl.gz",
+            "sample_rate": 16000,
+            "shuffle": True,
+            "num_workers": 2,
+            "batch_size": 4,
+            "shuffle_buffer_size": 100,
+        }
+    )
+
+    train_dataloader = get_lhotse_dataloader_from_config(
+        config,
+        global_rank=0,
+        world_size=1,
+        dataset=LhotseHfNeMoDataset(
+            processor=processor,
+            tokenizer=tokenizer,
+        ),
+        tokenizer=tokenizer,
+    )
+
+    speechlm.api.finetune(
+        model=model,
+        data=train_dataloader,
+        trainer=nl.Trainer(
+            devices=args.devices,
+            max_steps=args.max_steps,
+            accelerator=args.accelerator,
+            strategy=args.strategy,
+            precision="bf16-mixed",
+            log_every_n_steps=1,
+            limit_val_batches=0.0,
+            num_sanity_val_steps=0,
+            accumulate_grad_batches=10,
+            gradient_clip_val=0.5,
+            use_distributed_sampler=False,
+            callbacks=[],
+            logger=None,
+        ),
+        optim=fdl.build(speechlm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
+        log=None,
+    )
+
+    if args.model_save_path is not None:
+        model.save_pretrained(args.model_save_path)
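For reviewers, a small self-contained sketch (not part of the patch) of the label/decoder-input shifting that `LhotseHfNeMoDataset.__getitem__` performs; the token ids and `pad_id` below are made-up placeholders.

```python
import torch

# One collated supervision: [bos, tok, tok, tok, eos]; ids are illustrative only.
labels = torch.tensor([[50258, 291, 439, 257, 50257]])
pad_id = 50256      # hypothetical pad id
mask_fill = -100    # same sentinel as decoder_mask_fill in the dataset above

decoder_input_ids = labels[:, :-1]                      # teacher-forced decoder input
decoder_input_ids = decoder_input_ids.masked_fill(
    decoder_input_ids == mask_fill, pad_id              # replace the ignore-index with pad
)
targets = labels[:, 1:].reshape(-1)                     # next-token targets, flattened

print(decoder_input_ids.tolist())   # [[50258, 291, 439, 257]]
print(targets.tolist())             # [291, 439, 257, 50257]
```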
diff --git a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
index 14da2d13a030..54cf95296d3d 100644
--- a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
+++ b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
@@ -46,6 +46,7 @@ def __init__(
         additional_special_tokens: Optional[List] = [],
         use_fast: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
+        include_special_tokens: bool = False,
     ):
         """
         Args:
@@ -63,6 +64,7 @@ def __init__(
             unk_token: token to use for unknown tokens
             additional_special_tokens: list of other tokens beside standard special tokens (bos, eos, pad, etc.). For example, sentinel tokens for T5 (<extra_id_0>, <extra_id_1>, etc.)
             use_fast: whether to use fast HuggingFace tokenizer
+            include_special_tokens: when True, converting text to ids will include special tokens / prompt tokens (if any), yielding self.tokenizer(text).input_ids
         """
         try:
             # this logic deals with different huggingface tokenizers having different positional args
@@ -92,6 +94,7 @@ def __init__(
                 f'Unable to instantiate HuggingFace AUTOTOKENIZER for {pretrained_model_name}. Exception: {e}'
             )
 
+        self.include_special_tokens = include_special_tokens
         self.original_vocab_size = len(self.tokenizer)
 
         special_tokens_dict = {}
@@ -220,6 +223,8 @@ def ids_to_tokens(self, ids):
         return tokens
 
     def text_to_ids(self, text):
+        if self.include_special_tokens:
+            return self.tokenizer(text).input_ids
         tokens = self.text_to_tokens(text)
        ids = self.tokens_to_ids(tokens)
         return ids
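A rough illustration (not part of the patch) of what the new `include_special_tokens` flag changes; the exact ids and special tokens depend on the checkpoint's tokenizer, so the comments below are only indicative.

```python
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

# With the flag off (default), text_to_ids round-trips through
# text_to_tokens/tokens_to_ids and returns only the content token ids.
plain = AutoTokenizer("openai/whisper-large-v3")

# With the flag on, text_to_ids returns self.tokenizer(text).input_ids, which for
# Whisper also carries the prompt/special tokens (<|startoftranscript|>, language, task, ...).
special = AutoTokenizer("openai/whisper-large-v3", include_special_tokens=True)

text = "hello world"
print(plain.text_to_ids(text))    # content token ids only
print(special.text_to_ids(text))  # special/prompt ids + content token ids
```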
diff --git a/nemo/collections/speechlm/__init__.py b/nemo/collections/speechlm/__init__.py
new file mode 100755
index 000000000000..2b19e0be88fd
--- /dev/null
+++ b/nemo/collections/speechlm/__init__.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.speechlm.models import HFAutoModelForSpeechSeq2Seq
+from nemo.utils import logging
+
+__all__ = [
+    "HFAutoModelForSpeechSeq2Seq",
+]
+
+try:
+    import nemo_run as run
+
+    from nemo.collections.llm.recipes import adam
+    from nemo.collections.speechlm.api import finetune, generate, pretrain, train, validate
+
+    __all__.extend(
+        [
+            "train",
+            "pretrain",
+            "validate",
+            "finetune",
+            "generate",
+        ]
+    )
+except ImportError as error:
+    logging.warning(f"Failed to import nemo.collections.speechlm.[api, recipes]: {error}")
diff --git a/nemo/collections/speechlm/api.py b/nemo/collections/speechlm/api.py
new file mode 100644
index 000000000000..2342da6eb45c
--- /dev/null
+++ b/nemo/collections/speechlm/api.py
@@ -0,0 +1,442 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import lightning.pytorch as pl
+import nemo_run as run
+
+from typing_extensions import Annotated
+
+import nemo.lightning as nl
+from nemo.lightning import (
+    AutoResume,
+    NeMoLogger,
+    OptimizerModule,
+    Trainer,
+    configure_no_restart_validation_training_loop,
+)
+from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
+from nemo.utils import logging
+
+TokenizerType = Any
+
+
+@run.cli.entrypoint(namespace="speechlm")
+def train(
+    model: pl.LightningModule,
+    data: pl.LightningDataModule,
+    trainer: Trainer,
+    log: Annotated[Optional[NeMoLogger], run.Config[NeMoLogger]] = None,
+    resume: Annotated[Optional[AutoResume], run.Config[AutoResume]] = None,
+    optim: Optional[OptimizerModule] = None,
+    tokenizer: Optional[TokenizerType] = None,
+    model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None,
+    # TODO: Fix export export: Optional[str] = None,
+) -> Path:
+    """
+    Trains a model using the specified data and trainer, with optional tokenizer, source, and export.
+
+    Args:
+        model (pl.LightningModule): The model to be trained.
+        data (pl.LightningDataModule): The data module containing training data.
+        trainer (Trainer): The trainer instance configured with a MegatronStrategy.
+        log (NeMoLogger): A NeMoLogger instance.
+        resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint.
+        optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
+            from the model will be used.
+        tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
+            or an instance of TokenizerSpec.
+        export (Optional[str]): Filename to save the exported checkpoint after training.
+        model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.
+
+    Returns:
+        Path: The directory path where training artifacts are saved.
+
+    Examples:
+        >>> from nemo.collections import llm
+        >>> from nemo import lightning as nl
+        >>> model = llm.MistralModel()
+        >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
+        >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
+        >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
+        >>> llm.train(model, data, trainer, tokenizer="data")
+        PosixPath('/path/to/log_dir')
+    """
+    app_state = _setup(
+        model=model,
+        data=data,
+        trainer=trainer,
+        log=log,
+        resume=resume,
+        optim=optim,
+        tokenizer=tokenizer,
+        model_transform=model_transform,
+    )
+
+    trainer.fit(model, data)
+
+    return app_state.exp_dir
+
+
+@run.cli.entrypoint(namespace="speechlm")
+def pretrain(
+    model: pl.LightningModule,
+    data: pl.LightningDataModule,
+    trainer: Trainer,
+    log: Annotated[Optional[NeMoLogger], run.Config[NeMoLogger]] = None,
+    resume: Annotated[Optional[AutoResume], run.Config[AutoResume]] = None,
+    optim: Optional[OptimizerModule] = None,
+) -> Path:
+    """
+    Pretrains a model using the specified data and trainer, with optional logging, resuming, and optimization.
+
+    This function is a wrapper around the `train` function, specifically configured for pretraining tasks.
+    Note, by default it will use the tokenizer from the data.
+
+    Args:
+        model (pl.LightningModule): The model to be pretrained.
+        data (pl.LightningDataModule): The data module containing pretraining data.
+        trainer (Trainer): The trainer instance configured with a MegatronStrategy.
+        log (NeMoLogger): A NeMoLogger instance.
+        resume (Optional[AutoResume]): Resume training from a checkpoint.
+        optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default
+            optimizer from the model will be used.
+
+    Returns:
+        Path: The directory path where pretraining artifacts are saved.
+
+    Examples:
+        >>> from nemo.collections import llm
+        >>> from nemo import lightning as nl
+        >>> model = llm.MistralModel()
+        >>> data = llm.PretrainingDataModule(paths=[...], seq_length=4096, global_batch_size=16, micro_batch_size=2)
+        >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
+        >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
+        >>> llm.pretrain(model, data, trainer)
+        PosixPath('/path/to/log_dir')
+    """
+    _validate_config(model, data, trainer, log=log, resume=resume, optim=optim)
+    return train(
+        model=model,
+        data=data,
+        trainer=trainer,
+        log=log,
+        resume=resume,
+        optim=optim,
+        tokenizer="data",
+    )
+
+
+@run.cli.entrypoint(namespace="speechlm")
+def finetune(
+    model: pl.LightningModule,
+    data: pl.LightningDataModule,
+    trainer: Trainer,
+    log: Annotated[Optional[NeMoLogger], run.Config[NeMoLogger]] = None,
+    resume: Annotated[Optional[AutoResume], run.Config[AutoResume]] = None,
+    optim: Optional[OptimizerModule] = None,
+    peft: Optional[Union[PEFT, ModelTransform, Callable]] = None,
+) -> Path:
+    """
+    Finetunes a model using the specified data and trainer, with optional logging, resuming, and PEFT.
+
+    Note, by default it will use the tokenizer from the model.
+
+    Args:
+        model (pl.LightningModule): The model to be finetuned.
+        data (pl.LightningDataModule): The data module containing finetuning data.
+        trainer (Trainer): The trainer instance configured with a MegatronStrategy.
+        log (NeMoLogger): A NeMoLogger instance.
+        resume (Optional[AutoResume]): Resume training from a checkpoint.
+        optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default
+            optimizer from the model will be used.
+        peft (Optional[PEFT]): A PEFT (Parameter-Efficient Fine-Tuning) configuration to be applied.
+
+    Returns:
+        Path: The directory path where finetuning artifacts are saved.
+
+    Examples:
+        >>> from nemo.collections import llm
+        >>> from nemo import lightning as nl
+        >>> model = llm.MistralModel()
+        >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
+        >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
+        >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
+        >>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA())
+        PosixPath('/path/to/log_dir')
+    """
+
+    _validate_config(model, data, trainer, log=log, resume=resume, optim=optim, model_transform=peft)
+    return train(
+        model=model,
+        data=data,
+        trainer=trainer,
+        log=log,
+        resume=resume,
+        optim=optim,
+        tokenizer="model",
+        model_transform=peft,
+    )
+
+
+@run.cli.entrypoint(namespace="speechlm")
+def validate(
+    model: pl.LightningModule,
+    data: pl.LightningDataModule,
+    trainer: Trainer,
+    log: Annotated[Optional[NeMoLogger], run.Config[NeMoLogger]] = None,
+    resume: Annotated[Optional[AutoResume], run.Config[AutoResume]] = None,
+    optim: Optional[OptimizerModule] = None,
+    tokenizer: Optional[TokenizerType] = None,
+    model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None,
+) -> Path:
+    """
+    Validates a model using the specified data and trainer, with optional logging, resuming, and model transformations.
+
+    Args:
+        model (pl.LightningModule): The model to be validated.
+        data (pl.LightningDataModule): The data module containing validation data.
+        trainer (Trainer): The trainer instance configured with a MegatronStrategy.
+        log (NeMoLogger): A NeMoLogger instance.
+        resume (Optional[AutoResume]): Resume from a checkpoint for validation.
+        optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
+            from the model will be used.
+        tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
+            or an instance of TokenizerSpec.
+        model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.
+
+    Returns:
+        Path: The directory path where validation artifacts are saved.
+
+    Examples:
+        >>> from nemo.collections import llm
+        >>> from nemo import lightning as nl
+        >>> model = llm.MistralModel()
+        >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
+        >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
+        >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
+        >>> llm.validate(model, data, trainer, tokenizer="data")
+        PosixPath('/path/to/log_dir')
+    """
+    app_state = _setup(
+        model=model,
+        data=data,
+        trainer=trainer,
+        log=log,
+        resume=resume,
+        optim=optim,
+        tokenizer=tokenizer,
+        model_transform=model_transform,
+    )
+
+    trainer.validate(model, data)
+
+    return app_state.exp_dir
+
+
+def evaluate():
+    """
+    Evaluates a NeMo SpeechLM model.
+    """
+    raise NotImplementedError("This function will be implemented later")
+
+
+@run.cli.entrypoint(name="generate", namespace="speechlm")
+def generate():
+    """
+    Generates text using a NeMo Speech model.
+    """
+    raise NotImplementedError("This function will be implemented later")
+
+
+def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None:
+    if tokenizer == "data":
+        _set_with_io(model, "tokenizer", data.tokenizer)
+    elif tokenizer == "model":
+        _set_with_io(data, "tokenizer", model.tokenizer)
+    else:
+        try:
+            from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+
+            if isinstance(tokenizer, TokenizerSpec):
+                _set_with_io(model, "tokenizer", tokenizer)
+                _set_with_io(data, "tokenizer", tokenizer)
+            else:
+                raise ValueError(f"Expected TokenizerSpec or 'data' or 'model', got: {tokenizer}")
+        except ImportError:
+            raise ValueError("TokenizerSpec is not available")
+
+
+def _setup(
+    model: pl.LightningModule,
+    data: pl.LightningDataModule,
+    trainer: Trainer,
+    log: Optional[NeMoLogger],
+    resume: Optional[AutoResume],
+    optim: Optional[OptimizerModule],
+    tokenizer: Optional[TokenizerType],
+    model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
+) -> Any:  # Return type is Any because app_state's type is not specified
+    configure_no_restart_validation_training_loop(trainer)
+    _log = log or NeMoLogger()
+    if resume and isinstance(model_transform, PEFT) and _log.ckpt:
+        logging.info("Disabling try_restore_best_ckpt restoration for adapters")
+        _log.ckpt.try_restore_best_ckpt = False
+
+    app_state = _log.setup(
+        trainer,
+        resume_if_exists=getattr(resume, "resume_if_exists", False),
+        task_config=getattr(train, "__io__", None),
+    )
+    if resume is not None:
+        resume.setup(trainer, model)
+
+    if optim:
+        optim.connect(model)
+    if tokenizer:  # TODO: Improve this
+        _use_tokenizer(model, data, tokenizer)
+
+    if model_transform:
+        _set_with_io(model, "model_transform", model_transform)
+
+    # Add ModelTransform callback to Trainer if needed
+    if getattr(model, "model_transform", None):
+        if not any(isinstance(cb, ModelTransform) for cb in trainer.callbacks):
+            if isinstance(model_transform, ModelTransform):
+                trainer.callbacks.append(model_transform)
+            else:
+                trainer.callbacks.append(ModelTransform())
+
+    return app_state
+
+
+def _set_with_io(obj, attr, value):
+    setattr(obj, attr, value)
+    if hasattr(obj, "__io__") and hasattr(value, "__io__"):
+        setattr(obj.__io__, attr, deepcopy(value.__io__))
+
+
+def _validate_config(
+    model: pl.LightningModule,
+    data: pl.LightningDataModule,
+    trainer: Trainer,
+    log: Optional[NeMoLogger] = None,
+    resume: Optional[AutoResume] = None,
+    optim: Optional[OptimizerModule] = None,
+    tokenizer: Optional[TokenizerType] = None,
+    model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None,
+) -> None:
+
+    ## Model validation
+    if hasattr(model, "config"):
+        assert getattr(model.config, "seq_length", 1) > 0
+        assert getattr(model.config, "max_position_embeddings", 1) > 0
+        assert model.config.num_layers > 0
+        assert model.config.hidden_size > 0
+        assert model.config.num_attention_heads > 0
+        assert model.config.ffn_hidden_size > 0
+
+        if hasattr(model.config, "seq_length"):
+            if getattr(model.config, "max_position_embeddings", None) is not None:
+                assert model.config.seq_length <= model.config.max_position_embeddings
+    else:
+        assert not isinstance(trainer.strategy, nl.MegatronStrategy), "Expected model.config to exist"
+
+    ## Data validation
+    if hasattr(data, 'micro_batch_size'):
+        assert data.micro_batch_size > 0
+    if hasattr(data, 'global_batch_size'):
+        assert data.global_batch_size > 0
+    if hasattr(data, 'seq_length'):
+        assert data.seq_length > 0
+
+    if hasattr(data, 'micro_batch_size') and hasattr(data, 'global_batch_size'):
+        assert (
+            data.global_batch_size % data.micro_batch_size == 0
+        ), "Global batch size must be divisible by micro batch size in data module."
+
+    ## Trainer validation
+
+    # MegatronStrategy validation
+    if isinstance(trainer.strategy, nl.MegatronStrategy):
+        # Basic validation
+        assert trainer.strategy.tensor_model_parallel_size > 0
+        assert trainer.strategy.pipeline_model_parallel_size > 0
+        assert trainer.strategy.context_parallel_size > 0
+
+        # DP validation
+        assert (trainer.num_devices * trainer.num_nodes) % (
+            trainer.strategy.tensor_model_parallel_size
+            * trainer.strategy.pipeline_model_parallel_size
+            * trainer.strategy.context_parallel_size
+        ) == 0, "Number of GPUs must be divisible by the product of all parallelism sizes for data parallel."
+
+        assert (
+            data.global_batch_size
+            % (
+                data.micro_batch_size
+                * (
+                    (trainer.num_devices * trainer.num_nodes)
+                    / (
+                        trainer.strategy.tensor_model_parallel_size
+                        * trainer.strategy.pipeline_model_parallel_size
+                        * trainer.strategy.context_parallel_size
+                    )
+                )
+            )
+            == 0
+        ), "Global batch size must be divisible by the product of micro batch size and data parallel size"
+
+        # TP/SP validation
+        if trainer.strategy.tensor_model_parallel_size == 1:
+            if trainer.strategy.sequence_parallel:
+                warnings.warn("Disabling sequence parallelism because tensor model parallelism is disabled")
+                trainer.strategy.sequence_parallel = False
+
+        # PP/VP validation
+        if trainer.strategy.pipeline_model_parallel_size > 1:
+            assert (
+                trainer.strategy.pipeline_dtype is not None
+            ), "pipeline_dtype must be set if pipeline model parallelism is enabled"
+        else:
+            if trainer.strategy.virtual_pipeline_model_parallel_size is not None:
+                warnings.warn("Disabling virtual pipeline parallelism because pipeline model parallelism is disabled")
+                trainer.strategy.virtual_pipeline_model_parallel_size = None
+            if trainer.strategy.pipeline_dtype is not None:
+                warnings.warn("Setting pipeline dtype to None because pipeline model parallelism is disabled")
+                trainer.strategy.pipeline_dtype = None
+
+        # CP validation
+        if trainer.strategy.context_parallel_size > 1:
+            if hasattr(model, "config"):
+                if model.config.seq_length is not None:
+                    assert (
+                        model.config.seq_length % (trainer.strategy.context_parallel_size * 2) == 0
+                    ), 'Sequence length must be divisible by 2 * context parallel size if context parallel is used.'
+
+        # EP validation
+        if trainer.strategy.expert_model_parallel_size > 1:
+            if hasattr(model, "config"):
+                assert (
+                    model.config.num_moe_experts is not None
+                ), "num_moe_experts must not be None to use expert model parallelism"
+                assert (
+                    model.config.num_moe_experts % trainer.strategy.expert_model_parallel_size == 0
+                ), "Number of experts should be a multiple of expert model parallel_size."
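To make the divisibility checks in `_validate_config` concrete, here is a tiny worked example with hypothetical parallelism settings (not part of the patch):

```python
# Suppose 2 nodes x 8 GPUs, TP=2, PP=2, CP=1 (illustrative numbers).
num_devices, num_nodes = 8, 2
tp, pp, cp = 2, 2, 1

world_size = num_devices * num_nodes      # 16
model_parallel = tp * pp * cp             # 4
assert world_size % model_parallel == 0   # 16 % 4 == 0 -> valid data-parallel split
dp = world_size // model_parallel         # data parallel size = 4

# The global batch must then be a multiple of micro_batch_size * dp.
micro_batch_size, global_batch_size = 2, 32
assert global_batch_size % (micro_batch_size * dp) == 0   # 32 % 8 == 0
```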
diff --git a/nemo/collections/speechlm/models/__init__.py b/nemo/collections/speechlm/models/__init__.py
new file mode 100644
index 000000000000..a7d4e02cb4ca
--- /dev/null
+++ b/nemo/collections/speechlm/models/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.speechlm.models.hf_auto_model_for_speech_seq2seq import HFAutoModelForSpeechSeq2Seq
diff --git a/nemo/collections/speechlm/models/hf_auto_model_for_speech_seq2seq.py b/nemo/collections/speechlm/models/hf_auto_model_for_speech_seq2seq.py
new file mode 100644
index 000000000000..a039edc66a39
--- /dev/null
+++ b/nemo/collections/speechlm/models/hf_auto_model_for_speech_seq2seq.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import lightning.pytorch as pl
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+
+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.llm import fn
+from nemo.lightning import io
+from nemo.utils import logging
+
+
+def masked_cross_entropy(logits, targets, mask=None):
+    if mask is not None:
+        loss = F.cross_entropy(logits, targets, reduction='none')
+        return torch.mean(loss[mask == 1])
+    else:
+        return F.cross_entropy(logits, targets)
+
+
+class HFAutoModelForSpeechSeq2Seq(pl.LightningModule, io.IOMixin, fn.FNMixin):
+    def __init__(
+        self,
+        model_name='gpt2',
+        load_pretrained_weights=True,
+        tokenizer=None,
+        loss_fn=masked_cross_entropy,
+        model_transform=None,
+        model_accelerator=None,
+        trust_remote_code=False,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.model_name = model_name
+        self._tokenizer = None
+        self._processor = None
+        self.model = None
+        self.loss_fn = loss_fn
+        self.load_pretrained_weights = load_pretrained_weights
+        self.is_hf_model = True
+        self.model_transform = model_transform
+        self.model_accelerator = model_accelerator
+        self.trust_remote_code = trust_remote_code
+
+    @property
+    def tokenizer(self):
+        if self._tokenizer is None:
+            self._tokenizer = AutoTokenizer(
+                self.model_name, include_special_tokens=True, trust_remote_code=self.trust_remote_code
+            )
+        return self._tokenizer
+
+    @tokenizer.setter
+    def tokenizer(self, value):
+        assert self._tokenizer is None
+        self._tokenizer = value
+
+    @property
+    def processor(self):
+        if self._processor is None:
+            self._processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
+        return self._processor
+
+    @staticmethod
+    def configure_tokenizer(model_name):
+        return AutoProcessor.from_pretrained(model_name).tokenizer
+
+    def configure_model(self, train=True):
+        # lazily instantiate the underlying HF model on first use
+        if self.model is None:
+            if self.load_pretrained_weights:
+                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                    self.model_name,
+                    torch_dtype=torch.bfloat16,
+                    trust_remote_code=self.trust_remote_code,
+                    use_safetensors=True,
+                )
+            else:
+                from transformers import AutoConfig
+
+                config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
+                self.model = AutoModelForSpeechSeq2Seq.from_config(config, trust_remote_code=self.trust_remote_code)
+
+        if train:
+            self.model.train()
+
+    def forward(self, input_features, decoder_input_ids, attention_mask=None):
+        return self.model(
+            input_features=input_features.to(self.model.device),
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+        )
+
+    def training_step(self, batch):
+        outputs = self.forward(input_features=batch["input_features"], decoder_input_ids=batch["decoder_input_ids"])
+        loss_mask = batch.get('loss_mask', None)
+        if loss_mask is not None:
+            loss_mask = loss_mask.to(self.model.device).view(-1)
+        n_cls = outputs.logits.shape[-1]
+        logits = outputs.logits.view(-1, n_cls)
+        loss = self.loss_fn(logits, batch["labels"], loss_mask)
+
+        self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
+        return loss
+
+    def validation_step(self, batch):
+        output = self.forward(input_features=batch["input_features"], decoder_input_ids=batch["decoder_input_ids"])
+        loss = output.loss
+        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
+
+    def save_pretrained(self, path):
+        assert self.model is not None, "Model has to be created first."
+        self.model.save_pretrained(path)
+        if self._tokenizer is not None:
+            self._tokenizer.save_pretrained(path)
+        else:
+            logging.warning("No tokenizer was instantiated, so none will be saved.")
+
+        if self._processor is not None:
+            self._processor.save_pretrained(path)
+        else:
+            logging.warning("No processor was instantiated, so none will be saved.")
diff --git a/nemo/collections/speechlm/recipes/__init__.py b/nemo/collections/speechlm/recipes/__init__.py
new file mode 100644
index 000000000000..d9155f923f18
--- /dev/null
+++ b/nemo/collections/speechlm/recipes/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/nemo/collections/speechlm/recipes/optim/__init__.py b/nemo/collections/speechlm/recipes/optim/__init__.py
new file mode 100644
index 000000000000..d9155f923f18
--- /dev/null
+++ b/nemo/collections/speechlm/recipes/optim/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
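A standalone sanity check (not part of the patch) of the `masked_cross_entropy` loss used by `training_step`, on toy tensors:

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, targets, mask=None):
    # Same conventions as the model: logits (N, C), targets (N,), mask (N,) of 0/1.
    if mask is not None:
        loss = F.cross_entropy(logits, targets, reduction='none')
        return torch.mean(loss[mask == 1])
    return F.cross_entropy(logits, targets)

logits = torch.randn(4, 10)           # 4 positions, 10 classes
targets = torch.tensor([1, 2, 3, 4])
mask = torch.tensor([1, 1, 0, 0])     # only the first two positions contribute

print(masked_cross_entropy(logits, targets))        # mean over all 4 positions
print(masked_cross_entropy(logits, targets, mask))  # mean over the 2 unmasked positions
```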
diff --git a/nemo/collections/speechlm/recipes/optim/adam.py b/nemo/collections/speechlm/recipes/optim/adam.py
new file mode 100644
index 000000000000..777c3978c3e0
--- /dev/null
+++ b/nemo/collections/speechlm/recipes/optim/adam.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import nemo_run as run
+
+from nemo.lightning.pytorch.optim import PytorchOptimizerModule
+
+
+@run.cli.factory
+def pytorch_adam_with_flat_lr(
+    lr: float = 1e-5,
+) -> run.Config[PytorchOptimizerModule]:
+    from torch.optim import Adam
+
+    return run.Config(
+        PytorchOptimizerModule,
+        optimizer_fn=run.Partial(
+            Adam,
+            lr=lr,
+            weight_decay=0.1,
+            betas=(0.9, 0.95),
+            eps=1e-8,
+        ),
+    )
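As a usage note (not part of the patch), the recipe is meant to be consumed the way the example script consumes `speechlm.adam.pytorch_adam_with_flat_lr`; the import path below assumes the new `nemo.collections.speechlm.recipes.optim.adam` module is importable in your environment.

```python
import fiddle as fdl

from nemo.collections.speechlm.recipes.optim.adam import pytorch_adam_with_flat_lr

# The factory returns a run.Config describing a PytorchOptimizerModule wrapping torch.optim.Adam;
# fdl.build materializes it into the module that speechlm.api.finetune(optim=...) expects.
optim = fdl.build(pytorch_adam_with_flat_lr(lr=1e-5))
```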