diff --git a/examples/llm/peft/automodel.py b/examples/llm/peft/automodel.py index d9cb6799d3e7..b209dac998ef 100755 --- a/examples/llm/peft/automodel.py +++ b/examples/llm/peft/automodel.py @@ -21,7 +21,6 @@ from nemo import lightning as nl from nemo.collections import llm from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing -from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform # Run this example with torchrun, for example: @@ -171,9 +170,11 @@ def main(): ) callbacks = [] + jit_config = None if args.use_torch_jit: - jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': True}, use_thunder=False) - callbacks = [JitTransform(jit_config)] + from nemo.automodel.compiler import TorchCompileConfig + + jit_config = TorchCompileConfig(kwargs={'dynamic': True}) if args.use_te_optimizer: # Use TE optimizer @@ -195,6 +196,7 @@ def main(): use_liger_kernel=args.liger, load_in_4bit=args.load_in_4bit, enable_grad_ckpt=args.enable_grad_ckpt, + compiler_config=jit_config, ) strategy = make_strategy(args.strategy, model, args.devices, args.num_nodes, True, args.enable_cpu_offload) diff --git a/examples/llm/sft/automodel.py b/examples/llm/sft/automodel.py index 28f1280d3238..bb4dcb75981c 100755 --- a/examples/llm/sft/automodel.py +++ b/examples/llm/sft/automodel.py @@ -21,7 +21,6 @@ from nemo import lightning as nl from nemo.collections import llm from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing -from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform # Run this example with torchrun, for example: # torchrun --nproc-per-node=8 \ @@ -169,9 +168,11 @@ def main(): ) callbacks = [] + jit_config = None if args.use_torch_jit: - jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False) - callbacks = [JitTransform(jit_config)] + from nemo.automodel.compiler import TorchCompileConfig + + jit_config = TorchCompileConfig(kwargs={'dynamic': True}) if args.use_te_optimizer: # Use TE optimizer @@ -192,6 +193,7 @@ def main(): trust_remote_code=args.trust_remote_code, use_liger_kernel=args.liger, enable_grad_ckpt=args.enable_grad_ckpt, + compiler_config=jit_config, ) strategy = make_strategy(args.strategy, model, args.devices, args.num_nodes, False, args.enable_cpu_offload) diff --git a/nemo/automodel/compiler/__init__.py b/nemo/automodel/compiler/__init__.py new file mode 100644 index 000000000000..dd5ad2bca38d --- /dev/null +++ b/nemo/automodel/compiler/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# VISION +from nemo.automodel.compiler.configs import ThunderConfig, TorchCompileConfig +from nemo.automodel.compiler.module_compiler import compile_module, compile_module_from_config +from nemo.automodel.compiler.utils import extract_module_attr_name, get_modules_from_selector, listify + +__all__ = [ + "TorchCompileConfig", + "ThunderConfig", + "compile_module", + "compile_module_from_config", + "extract_module_attr_name", + "listify", + "get_modules_from_selector", +] diff --git a/nemo/automodel/compiler/configs.py b/nemo/automodel/compiler/configs.py new file mode 100644 index 000000000000..4450a0fe542f --- /dev/null +++ b/nemo/automodel/compiler/configs.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + + +@dataclass +class TorchCompileConfig: + """Config for torch.compile + Options: + - module_selector (str): reg-exp to match modules to compile, useful for multi-trunk + models where you want to apply it on one of them only. If empty will apply transform to root + module. + - apply_pre_wrap: if True will compile before wrapping with DDP/FSDP2 & vice-versa. + - kwargs (dict): kwargs to pass to torch.compile. + """ + + module_selector: str = '' + apply_pre_wrap: bool = True + kwargs: dict = field(default_factory=dict) + + +@dataclass +class ThunderConfig: + """Config for Thunder + Options: + - module_selector (str): reg-exp to match modules to compile, useful for multi-trunk + models where you want to apply it on one of them only. If empty will apply transform to root + module. + - apply_pre_wrap: if True will compile before wrapping with DDP/FSDP2 & vice-versa. + - kwargs (dict): kwargs to pass to thunder, currently unused. + - profile (bool): toggle for thunder's profiler. + """ + + module_selector: str = '' + apply_pre_wrap: bool = True + kwargs: dict = field(default_factory=dict) + profile: bool = False diff --git a/nemo/automodel/compiler/module_compiler.py b/nemo/automodel/compiler/module_compiler.py new file mode 100644 index 000000000000..69e4e9a5b87e --- /dev/null +++ b/nemo/automodel/compiler/module_compiler.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import torch
+
+
+from nemo.automodel.compiler.configs import ThunderConfig, TorchCompileConfig
+from nemo.automodel.compiler.utils import extract_module_attr_name, get_modules_from_selector
+
+
+def compile_module(config, module):
+    """Jit-compiles an nn.Module in place.
+
+    Args:
+        config (TorchCompileConfig, ThunderConfig): compiler config
+        module (nn.Module): the module to be compiled
+
+    Raises:
+        ValueError: if config is neither a TorchCompileConfig nor a ThunderConfig.
+    """
+    if isinstance(config, TorchCompileConfig):
+        module.compile(**(config.kwargs or {}))
+    elif isinstance(config, ThunderConfig):
+        import thunder
+        import thunder.dynamo
+        from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform
+
+        # With this setting, Dynamo Graphs inline all the modules (so the Dynamo FXGraph
+        # consists of `call_function` nodes only and has no `call_module` nodes).
+        # This is the default setting in PyTorch 2.5 onwards
+        # (see https://github.com/pytorch/pytorch/pull/131275)
+        torch._dynamo.config.inline_inbuilt_nn_modules = True
+
+        xforms: list = [NvtxProfileTransform()] if config.profile else []
+        module.compile(backend=thunder.dynamo.ThunderCompiler(transforms=xforms))
+    else:
+        raise ValueError("Expected config to be TorchCompileConfig or ThunderConfig")
+
+
+def compile_module_from_config(config, module) -> None:
+    """Jit-compiles the (sub)modules of `module` selected by `config.module_selector`.
+
+    The wrapper is tagged with a `_compiled` attribute so repeated calls (e.g. once per
+    epoch, after PEFT adapters have been added) become no-ops.
+
+    Args:
+        config (TorchCompileConfig, ThunderConfig, None): compiler config; if None, nothing is done.
+        module (nn.Module): the nn.Module (or LightningModule wrapper) to compile.
+    """
+    if config is None:
+        return
+    if not isinstance(config, (TorchCompileConfig, ThunderConfig)):
+        return
+
+    attr_name = extract_module_attr_name(module)
+    model = getattr(module, attr_name) if isinstance(attr_name, str) else attr_name
+
+    if getattr(module, '_compiled', False):
+        return
+
+    # TODO(@akoumparouli): concatenate (via regex OR-operator) all expressions and trigger
+    # the compile if any one matches, instead of iterating over all O(N^2).
+    compiled = False
+    for sub_module in get_modules_from_selector(model, config.module_selector):
+        compile_module(config, sub_module)
+        compiled = True
+
+    setattr(module, '_compiled', compiled)
diff --git a/nemo/automodel/compiler/utils.py b/nemo/automodel/compiler/utils.py
new file mode 100644
index 000000000000..81e3bfd7370e
--- /dev/null
+++ b/nemo/automodel/compiler/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import torch
+import torch.nn as nn
+
+
+def extract_module_attr_name(pl_module: "pl.LightningModule") -> str:
+    """Extracts the held nn.Module from a pl.LightningModule, trying "module", then "model".
+
+    Args:
+        pl_module (pl.LightningModule): the LightningModule used in training.
+
+    Raises:
+        ValueError: if pl_module has neither a .model nor a .module attribute and is not
+            an nn.Module itself.
+
+    Returns:
+        str | nn.Module: the attribute name holding the nn.Module, or pl_module itself if it
+            is already a plain nn.Module.
+    """
+    if hasattr(pl_module, 'module'):
+        return 'module'
+    elif hasattr(pl_module, 'model'):
+        return 'model'
+    elif isinstance(pl_module, nn.Module):
+        return pl_module
+    else:
+        raise ValueError("Expected lightning_module to have a .model or .module attr.")
+
+
+def listify(x):
+    """Wraps input in a list, if not already a list.
+
+    Args:
+        x (Anything): the input, can be anything.
+
+    Returns:
+        Anything | list(Anything): x unchanged if it is already a list, otherwise [x].
+    """
+    if not isinstance(x, list):
+        return [x]
+    return x
+
+
+def get_modules_from_selector(model, module_selector):
+    """Iterator over model's modules whose FQN match the module_selector.
+
+    Args:
+        model (nn.Module): the model to iterate over.
+        module_selector (str): module selector; if empty or '*' the whole model is yielded.
+            If an atom contains an asterisk it is matched as a regexp against child names.
+
+    Raises:
+        AttributeError: if the user provides an invalid selector.
+        AttributeError: if the selector selects a non-nn.Module attribute.
+
+    Yields:
+        Iterator(nn.Module): iterator over modules whose FQN matches module_selector
+    """
+    if module_selector is None or module_selector == '' or module_selector == '*':
+        yield model
+        return
+
+    assert isinstance(module_selector, str), module_selector
+    atoms: list[str] = module_selector.split('.')
+    tmp = model
+
+    for i, item in enumerate(atoms):
+        if '*' in item:
+            # handle wildcard selector
+            # TODO(@akoumparouli): support more complex selectors e.g. net_b.*.net_c.*.conv
+            for name, module in tmp.named_children():
+                if re.match(item.replace('*', '.*'), name):
+                    yield module
+            return
+
+        if not hasattr(tmp, item):
+            raise AttributeError(tmp._get_name() + " has no attribute `" + item + "`")
+        tmp = getattr(tmp, item)
+
+        if not isinstance(tmp, torch.nn.Module):
+            raise AttributeError("`" + item + "` is not an nn.Module")
+
+    yield tmp
diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index 04a5e428c008..3c62b9556e6e 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -49,7 +49,7 @@
 )
 from nemo.lightning.base import NEMO_MODELS_CACHE
 from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
-from nemo.lightning.pytorch.callbacks import PEFT, JitTransform, ModelTransform
+from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
 from nemo.utils import logging
 from nemo.utils.get_rank import is_global_rank_zero

@@ -1079,14 +1079,6 @@ def _setup(
             trainer.callbacks.append(model_transform)
         else:
             trainer.callbacks.append(ModelTransform())
-    # Move jit callback at the end ensure it's applied on top of any model transformations (peft)
-    jit_cb = None
-    for i, cb in enumerate(trainer.callbacks):
-        if isinstance(cb, JitTransform):
-            assert jit_cb is None
-            jit_cb = trainer.callbacks.pop(i)
-    if jit_cb is not None:
-        trainer.callbacks.append(jit_cb)

     return app_state

diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
index f815b44226d7..3b541573f7de 100644
--- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -51,6 +51,7 @@ def __init__(
         use_liger_kernel=False,
         enable_grad_ckpt=False,
         device_map="cpu",
+        compiler_config=None,
     ):
         """
         Initialize the HFAutoModelForCausalLM.
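
For context, a minimal sketch of how the new compiler_config argument is meant to be used in place of the old JitConfig/JitTransform callback pair. This follows the updated examples in this patch; the model name is a placeholder, not part of the diff.

    from nemo.collections import llm
    from nemo.automodel.compiler import TorchCompileConfig

    # kwargs are forwarded to torch.compile, mirroring the {'dynamic': True} used in the examples above
    jit_config = TorchCompileConfig(kwargs={'dynamic': True})

    # The config is handed to the model and applied later via compile_module_from_config
    model = llm.HFAutoModelForCausalLM(model_name='<hf-model-id>', compiler_config=jit_config)
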
diff --git a/nemo/lightning/pytorch/callbacks/__init__.py b/nemo/lightning/pytorch/callbacks/__init__.py index b3a3074f4992..031f027e63b2 100755 --- a/nemo/lightning/pytorch/callbacks/__init__.py +++ b/nemo/lightning/pytorch/callbacks/__init__.py @@ -15,7 +15,6 @@ from nemo.lightning.pytorch.callbacks.ddp_parity_checker import DdpParityChecker from nemo.lightning.pytorch.callbacks.debugging import ParameterDebugger from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback -from nemo.lightning.pytorch.callbacks.jit_transform import JitConfig, JitTransform from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback from nemo.lightning.pytorch.callbacks.model_callback import ModelCallback from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint @@ -39,6 +38,4 @@ "GarbageCollectionCallback", "ParameterDebugger", "ModelCallback", - "JitTransform", - "JitConfig", ] diff --git a/nemo/lightning/pytorch/callbacks/jit_transform.py b/nemo/lightning/pytorch/callbacks/jit_transform.py deleted file mode 100644 index 33e76555f65d..000000000000 --- a/nemo/lightning/pytorch/callbacks/jit_transform.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from dataclasses import dataclass, field - -import torch -from lightning.pytorch.callbacks.callback import Callback - -from nemo.lightning.io.mixin import IOMixin - - -def extract_module_attr_name(pl_module: "pl.LightningModule") -> str: - """Extracts the held nn.Module from a pl.LightningModule, will try "module", "model", or fail. - - Args: - pl_module (pl.LightningModule): the LightningModule used in training. - - Raises: - ValueError: if the pl_module has neither a .mdoel or .module - - Returns: - str: the attr-name of the nn.Module - """ - if hasattr(pl_module, 'module'): - return 'module' - elif hasattr(pl_module, 'model'): - return 'model' - else: - raise ValueError("Expected lightning_module to have a .model or .module attr.") - - -def listify(x): - """Wraps input in a list, if not already a list. - - Args: - x (Anything): the input, can be anything. - - Returns: - Anything | list(Anything): Anything (if it's already a list) o/w list(Anything) - """ - if not isinstance(x, list): - return [x] - return x - - -def get_modules_from_selector(model, module_selector): - """Iterator over model's modules whose FQN match the module_selector. - - Args: - model (nn.Module): the model to iterate over. - module_selector (str): module selector, if empty or '*' will return the whole model. If - there's an asterisk in the name will match it as a regexp. - - Raises: - AttributeError: if the user provides an invalid selector. - AttributeError: if user's selector selects a non-nn.Module attribute. 
- - Yields: - Iterator(nn.Module): iterator over modules whose FQN matches module_selector - """ - if module_selector is None or module_selector == '' or module_selector == '*': - yield model - return - - assert isinstance(module_selector, str), module_selector - atoms: List[str] = module_selector.split('.') - tmp = model - - for i, item in enumerate(atoms): - if '*' in item: - # handle wildcard selector - # TODO(@akoumparouli): support more complex selectors e.g. net_b.*.net_c.*.conv - for name, module in tmp.named_children(): - if re.match(item.replace('*', '.*'), name): - yield module - return - - if not hasattr(tmp, item): - raise AttributeError(tmp._get_name() + " has no " "attribute `" + item + "`") - tmp = getattr(tmp, item) - - if not isinstance(tmp, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - - yield tmp - - -def compile_module(config, module): - """Jit-compiles an nn.Module - - Args: - config (JitConfig): jit config - module (nn.Module): the module to be compiled - - Returns: - nn.Module: the (potentially) compiled module - """ - if config.use_torch: - module.compile(**config.torch_kwargs) - return True - elif config.use_thunder: - import thunder - import thunder.dynamo - from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform - - # With this setting, Dynamo Graphs inline all the modules (so Dynamo FXGraph just - # consists of `call_function` nodes only and no `call_module` node. - # This is the default setting in PyTorch 2.5 onwards - # (see https://github.com/pytorch/pytorch/pull/131275) - torch._dynamo.config.inline_inbuilt_nn_modules = True - - xforms: list = [NvtxProfileTransform()] if config.profile_thunder else [] - module.compile(backend=thunder.dynamo.ThunderCompiler(transforms=xforms)) - return True - else: - return False - - -@dataclass -class JitConfig: - """Config POD for Jit transforms (e.g. torch.compile or thunder) - Options: - - module_selector (str): reg-exp to match modules to apply JitTransform to, useful for multi-trunk - models where you want to apply it on one of them only. If empty will apply transform to root - module. - - use_torch (bool): whether to use torch.compile or not. - - torch_kwargs (dict): kwargs to pass to torch.compile. - - use_thunder (bool): whether to use thunder or not. - - profile_thunder (bool): toggle for thunder's profiler. - """ - - module_selector: str = '' - use_torch: bool = False - torch_kwargs: dict = field(default_factory=dict) - use_thunder: bool = False - profile_thunder: bool = False - - def __post_init__(self): - assert not (self.use_torch and self.use_thunder), "use_torch cannot be used at the same time with use_thunder" - - -class JitTransform(Callback, IOMixin): - """ - Apply JIT-compling on PyTorch model - - Args: - config (JitConfig): The jit-compiler config to use. - - Example: - >>> from nemo.lightning.pytorch.callbacks import JitTransform - >>> trainer = Trainer(callbacks=[JitTransform(JitConfig(use_torch=True))]) - """ - - def __init__(self, config: JitConfig): - assert config is not None - self.config = config - assert not (self.config.use_torch and self.config.use_thunder) - - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Jit-compiles the model at the start of the epoch. - While other events such as on_train_start are more suitable, we use on_train_epoch_start - since that is what is used in peft (we want to jit after adding the adapters). 
- - Args: - trainer (pl.Trainer): PTL trainer - pl_module (pl.LightningModule): PTL module - """ - if self.config is None: - return - if not self.config.use_thunder and not self.config.use_torch: - return - - attr_name = extract_module_attr_name(pl_module) - model = getattr(pl_module, attr_name) - - if getattr(pl_module, '_compiled', False) == True: - return - - # TODO(@akoumparouli): you want to concatenate (via regex OR-operator) all expressions - # and trigger the compile if anyone matches, instead of iterating over all O(N^2). - compiled = False - for config in listify(self.config): - for module in get_modules_from_selector(model, config.module_selector): - compiled |= compile_module(config, module) - - setattr(pl_module, '_compiled', compiled) diff --git a/tests/lightning/pytorch/callbacks/test_jit_transform.py b/tests/automodel/compiler/test_compiler.py similarity index 51% rename from tests/lightning/pytorch/callbacks/test_jit_transform.py rename to tests/automodel/compiler/test_compiler.py index def78ba1c203..89c27f8d839c 100644 --- a/tests/lightning/pytorch/callbacks/test_jit_transform.py +++ b/tests/automodel/compiler/test_compiler.py @@ -18,15 +18,15 @@ import re from dataclasses import dataclass, field -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest import torch -from nemo.lightning.pytorch.callbacks.jit_transform import ( - JitConfig, - JitTransform, +from nemo.automodel.compiler import ( + TorchCompileConfig, compile_module, + compile_module_from_config, extract_module_attr_name, get_modules_from_selector, listify, @@ -112,64 +112,106 @@ def test_get_modules_from_selector_wildcard_children(): def test_jit_config_assertion(): - # Should raise if both use_torch and use_thunder - with pytest.raises(AssertionError): - JitConfig(use_torch=True, use_thunder=True) + # Should raise if not TorchCompileConfig / ThunderConfig + mock_module = MagicMock() + with pytest.raises(ValueError): + compile_module({}, mock_module) -def test_compile_module_torch(): - mock_module = MagicMock() - config = JitConfig(use_torch=True, torch_kwargs={"some_arg": 123}) - compiled = compile_module(config, mock_module) - mock_module.compile.assert_called_once_with(some_arg=123) - assert compiled +@pytest.fixture +def mock_module_torch(): + # Create a mock that pretends to be an nn.Module + pl_module = MagicMock(spec=torch.nn.Module) + # Make sure there's a .model attribute that we can access + pl_module.model = MagicMock(spec=torch.nn.Module) + # Initially, pretend it is not yet compiled + pl_module._compiled = False + return pl_module -def test_compile_module_thunder(): - mock_module = MagicMock() - config = JitConfig(use_thunder=True) - compiled = compile_module(config, mock_module) - mock_module.compile.assert_called_once() - assert compiled +def test_compile_module_torch(mock_module_torch): + config = TorchCompileConfig(kwargs={"some_arg": 123}) + compile_module_from_config(config, mock_module_torch) + mock_module_torch.compile.assert_called_once_with(some_arg=123) + assert mock_module_torch._compiled == True + + +def test_compile_module_torch_with_path(mock_module_torch): + config = TorchCompileConfig(module_selector="block1", kwargs={"some_arg": 123}) + + # Ensure there's a `block1` attribute on the module so getattr(module, "block1") works + mock_module_torch.block1 = MagicMock(spec=torch.nn.Module) + + with ( + patch("nemo.automodel.compiler.module_compiler.extract_module_attr_name", return_value="model"), + patch( + 
"nemo.automodel.compiler.module_compiler.get_modules_from_selector", + return_value=[mock_module_torch.block1], + ), + patch("nemo.automodel.compiler.module_compiler.compile_module") as mock_compile, + ): + + compile_module_from_config(config, mock_module_torch) + mock_compile.assert_called_once_with(config, mock_module_torch.block1) + + # Now ensure _compiled was set to True + assert mock_module_torch._compiled is True def test_compile_module_none(): mock_module = MagicMock() - config = JitConfig() - compiled = compile_module(config, mock_module) + config = None + with pytest.raises(ValueError): + compile_module(config, mock_module) mock_module.compile.assert_not_called() - assert not compiled - -def test_jit_transform_no_config(): - # If config is None, on_train_epoch_start returns early - transform = JitTransform(JitConfig(use_thunder=False, use_torch=False)) - trainer_mock = MagicMock() - pl_module = MagicMock(spec=[]) - transform.on_train_epoch_start(trainer_mock, pl_module) - assert not getattr(pl_module, '_compiled', False) +@pytest.mark.parametrize("config_class", [TorchCompileConfig]) +def test_compile_sets_compiled_flag(config_class): + # Arrange + pl_module = MagicMock() + # By default, pretend it's not compiled yet: + setattr(pl_module, "_compiled", False) + config = config_class() -def test_jit_transform_already_compiled(): - transform = JitTransform(JitConfig(use_torch=True)) - trainer_mock = MagicMock() - pl_module = MagicMock(spec=[]) - pl_module._compiled = True - pl_module.module = True - transform.on_train_epoch_start(trainer_mock, pl_module) - # Should remain True, and compile should not be called again - assert pl_module._compiled is True - assert pl_module.module == True + # Mock out dependencies + with ( + patch("nemo.automodel.compiler.extract_module_attr_name", return_value="model"), + patch("nemo.automodel.compiler.get_modules_from_selector", return_value=[MagicMock()]), + patch("nemo.automodel.compiler.compile_module"), + ): + # Act + compile_module_from_config(config, pl_module) + # Assert + assert getattr(pl_module, "_compiled") is True -def test_jit_transform_compile_once(): - # simulate successful compile (torch or thunder) - transform = JitTransform(JitConfig(use_torch=True)) - trainer_mock = MagicMock() - # pl_module with the 'module' attribute (matching whatever name you expect inside transform) +def test_compile_does_not_set_compiled_when_config_is_none(): + # Arrange pl_module = MagicMock() - pl_module.module = MagicMock() + setattr(pl_module, "_compiled", False) + + # Act + compile_module_from_config(None, pl_module) - transform.on_train_epoch_start(trainer_mock, pl_module) - assert pl_module._compiled is True + # Assert + assert getattr(pl_module, "_compiled") is False + + +def test_compile_skips_if_already_compiled(): + # Arrange + pl_module = MagicMock() + setattr(pl_module, "_compiled", True) + config = TorchCompileConfig() + + with ( + patch("nemo.automodel.compiler.utils.extract_module_attr_name", return_value="model"), + patch("nemo.automodel.compiler.get_modules_from_selector") as mock_selector, + ): + # Act + compile_module_from_config(config, pl_module) + + # Assert: no further calls should happen if _compiled was True + mock_selector.assert_not_called() + assert getattr(pl_module, "_compiled") is True # remains True, unchanged diff --git a/tests/collections/llm/hf/peft_hf.py b/tests/collections/llm/hf/peft_hf.py index 5f1578f7803f..685d3e9ad51a 100644 --- a/tests/collections/llm/hf/peft_hf.py +++ b/tests/collections/llm/hf/peft_hf.py @@ -27,7 
+27,6 @@ from nemo import lightning as nl from nemo.collections import llm from nemo.lightning import NeMoLogger -from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform from nemo.lightning.pytorch.strategies.utils import to_cpu DATA_PATH = '/home/TestData/lite/hf_cache/squad/' @@ -279,11 +278,13 @@ def main(): ) callbacks = [] + jit_config = None if args.use_torch_jit: - jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': True}, use_thunder=False) - callbacks = [JitTransform(jit_config)] + from nemo.automodel.compiler import TorchCompileConfig - model = llm.HFAutoModelForCausalLM(model_name=args.model) + jit_config = TorchCompileConfig(kwargs={'dynamic': True}) + + model = llm.HFAutoModelForCausalLM(model_name=args.model, compiler_config=jit_config) strategy = make_strategy(args.strategy, model, args.devices, args.num_nodes, True) if args.auto_resume: diff --git a/tests/collections/llm/hf/sft.py b/tests/collections/llm/hf/sft.py index c76406d5f8e6..a85ebbb5d849 100755 --- a/tests/collections/llm/hf/sft.py +++ b/tests/collections/llm/hf/sft.py @@ -27,7 +27,6 @@ from nemo import lightning as nl from nemo.collections import llm -from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform from nemo.lightning.pytorch.strategies.utils import to_cpu DATA_PATH = '/home/TestData/lite/hf_cache/squad/' @@ -322,9 +321,11 @@ def main(): model_accelerator = TEConfig(fp8_autocast=args.fp8_autocast) callbacks = [] + jit_config = None if args.use_torch_jit: - jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False) - callbacks = [JitTransform(jit_config)] + from nemo.automodel.compiler import TorchCompileConfig + + jit_config = TorchCompileConfig(kwargs={'dynamic': True}) if args.auto_resume: callbacks.append(ValidateCheckpointRestoreCallback()) @@ -337,7 +338,7 @@ def configure_module(self, *args, **kwargs): return ans model_cls = ZeroInitHFAutoModelForCausalLM if args.auto_resume else llm.HFAutoModelForCausalLM - model = model_cls(model_name=args.model, model_accelerator=model_accelerator) + model = model_cls(model_name=args.model, model_accelerator=model_accelerator, compiler_config=jit_config) strategy = make_strategy(args.strategy, model, args.devices, args.num_nodes, False) diff --git a/tests/collections/llm/test_nemo_jit_cb.py b/tests/collections/llm/test_nemo_jit_cb.py deleted file mode 100644 index 89859333541d..000000000000 --- a/tests/collections/llm/test_nemo_jit_cb.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import itertools - -import fiddle as fdl -import lightning as pl -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader - -from nemo import lightning as nl -from nemo.collections import llm -from nemo.collections.llm import fn -from nemo.lightning import io -from nemo.lightning.io.mixin import track_io -from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform - -DATA_PATH = '/home/TestData/lite/hf_cache/squad/' - - -def make_squad_hf_dataset(data_path, tokenizer): - tokenizer = getattr(tokenizer, 'tokenizer', tokenizer) - - def fmt(examples): - prompt = """ - ### Instruction: - {} - - ### Input: - {} - - ### Response: - {}""" - instruction = examples["context"] - input = examples["question"] - output = examples["answers"]['text'] - if isinstance(output, list): - output = output[0] - text = prompt.format(instruction, input, output) + "" - tokens = tokenizer.text_to_ids(text) - return {'input_ids': tokens, 'labels': tokens} - - datamodule = llm.HFDatasetDataModule(data_path, split="train[:100]", pad_token_id=tokenizer.eos_id) - - datamodule.map( - fmt, - batched=False, - batch_size=2, - remove_columns=["id", "title", "context", "question", 'answers'], - ) - - return datamodule - - -@track_io -class OrdTokenizer: - def __init__(self, vocab_size=30_000, num_reserved_tokens=128, special_token_names=['bos_id', 'eos_id', 'pad_id']): - self.vocab_size = vocab_size - self.num_reserved_tokens = num_reserved_tokens - self.special_token_names = special_token_names - assert len(self.special_token_names) < num_reserved_tokens - - def __getattr__(self, name): - if name in self.__dict__.get('special_token_names', {}): - return self.__dict__['special_token_names'].index(name) - elif name in self.__dict__: - return self.__dict__[name] - else: - raise AttributeError - - def text_to_ids(self, text): - token_ids = list(map(lambda x: self.num_reserved_tokens + ord(x), list(text))) - assert max(token_ids) < self.vocab_size - return token_ids - - -def align_labels(logits, labels): - logits = logits.float() - n_cls = logits.shape[-1] - if logits.shape[-2] == labels.shape[-1]: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - elif logits.shape[-2] == labels.shape[-1] + 1: - logits = logits[..., :-1, :].contiguous() - else: - raise ValueError("Mismatched labels and logits shapes (" + str(labels.shape) + " " + str(logits.shape)) - return logits.view(-1, n_cls), labels.view(-1) - - -class DummyJitModel(pl.LightningModule, io.IOMixin, fn.FNMixin): - def __init__( - self, - tokenizer=None, - has_jit=False, - ): - super().__init__() - self.has_jit = has_jit - self.tokenizer = tokenizer - - def configure_model(self) -> None: - if not hasattr(self, "module"): - self.module = nn.Sequential( - nn.Embedding(30_000, 512), - nn.TransformerEncoderLayer(512, 8, 4096, dropout=0.1), - nn.Linear(512, 30_000), - ) - - def forward(self, batch): - output = self.module(**batch) - if self.has_jit: - assert self.module._compiled_call_impl is not None - assert callable(self.module._compiled_call_impl) - else: - assert self.module._compiled_call_impl is None - expected_cls = torch.nn.modules.container.Sequential - assert isinstance(self.module, expected_cls), type(self.module) - return output - - def training_step(self, batch): - if self.has_jit: - assert hasattr(self, '_compiled') - assert self._compiled == True, self._compiled - else: - assert not hasattr(self, '_compiled') - labels = batch.pop('labels') - loss_mask = 
batch.get('loss_mask', None) - output = self.forward({'input': batch['input_ids']}) - logits, labels = align_labels(output, labels) - return F.cross_entropy(logits, labels) - - -if __name__ == '__main__': - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument('--devices', default=1) - parser.add_argument('--max-steps', type=int, default=1) - args = parser.parse_args() - - tokenizer = OrdTokenizer() - data = make_squad_hf_dataset(DATA_PATH, tokenizer) - - for use_torch, use_thunder in itertools.product([True, False], [False, False]): - if use_torch and use_thunder: - continue - model = DummyJitModel(tokenizer=tokenizer, has_jit=use_torch | use_thunder) - optim = fdl.build(llm.sgd.pytorch_sgd_with_flat_lr(lr=1e-5)) - - jit_config = JitConfig(use_torch=use_torch, use_thunder=use_thunder) - transform = JitTransform(jit_config) - - llm.api.finetune( - model=model, - data=data, - trainer=nl.Trainer( - devices=args.devices, - max_steps=args.max_steps, - accelerator='gpu', - strategy='auto', - log_every_n_steps=1, - limit_val_batches=0.0, - num_sanity_val_steps=0, - accumulate_grad_batches=1, - gradient_clip_val=1.0, - use_distributed_sampler=False, - callbacks=[transform], - ), - optim=optim, - log=None, - )
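
To illustrate the selector semantics implemented in nemo/automodel/compiler/utils.py, a small self-contained sketch under stated assumptions: the Wrapper/Trunk classes and the 'block*' selector below are illustrative, not taken from the diff.

    import torch.nn as nn

    from nemo.automodel.compiler import TorchCompileConfig, compile_module_from_config


    class Trunk(nn.Module):
        def __init__(self):
            super().__init__()
            self.block0 = nn.Linear(16, 16)
            self.block1 = nn.Linear(16, 16)


    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            # 'model' is one of the attribute names extract_module_attr_name looks for
            self.model = Trunk()


    wrapper = Wrapper()
    # 'block*' is matched as a regexp against child names, so block0 and block1 are both compiled;
    # kwargs go straight to torch.compile (compilation itself happens lazily on first forward).
    cfg = TorchCompileConfig(module_selector='block*', kwargs={'dynamic': True})
    compile_module_from_config(cfg, wrapper)
    assert wrapper._compiled  # set by compile_module_from_config once any selected module was compiled
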