[automodel] Move jittransform to library #12880

Closed · wants to merge 11 commits
28 changes: 28 additions & 0 deletions nemo/automodel/compiler/__init__.py
@@ -0,0 +1,28 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.automodel.compiler.configs import ThunderConfig, TorchCompileConfig
from nemo.automodel.compiler.module_compiler import compile_module, compile_module_from_config
from nemo.automodel.compiler.utils import extract_module_attr_name, get_modules_from_selector, listify

__all__ = [
"TorchCompileConfig",
"ThunderConfig",
"compile_module",
"compile_module_from_config",
"extract_module_attr_name",
"listify",
"get_modules_from_selector",
]
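
With this package in place, downstream code imports the compiler utilities from the new location. A minimal import sketch, assuming the layout introduced in this PR:

```python
# Minimal import sketch, assuming the nemo.automodel.compiler layout from this PR.
from nemo.automodel.compiler import (
    ThunderConfig,
    TorchCompileConfig,
    compile_module,
    compile_module_from_config,
)
```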
49 changes: 49 additions & 0 deletions nemo/automodel/compiler/configs.py
@@ -0,0 +1,49 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field


@dataclass
class TorchCompileConfig:
"""Config for torch.compile
Options:
- module_selector (str): reg-exp to match modules to compile, useful for multi-trunk
models where you want to apply it on one of them only. If empty will apply transform to root
module.
- apply_pre_wrap: if True will compile before wrapping with DDP/FSDP2 & vice-versa.
- kwargs (dict): kwargs to pass to torch.compile.
"""

module_selector: str = ''
apply_pre_wrap: bool = True
kwargs: dict = field(default_factory=dict)


@dataclass
class ThunderConfig:
"""Config for Thunder
Options:
- module_selector (str): reg-exp to match modules to compile, useful for multi-trunk
models where you want to apply it on one of them only. If empty will apply transform to root
module.
- apply_pre_wrap: if True will compile before wrapping with DDP/FSDP2 & vice-versa.
- kwargs (dict): kwargs to pass to thunder, currently unused.
- profile (bool): toggle for thunder's profiler.
"""

module_selector: str = ''
apply_pre_wrap: bool = True
kwargs: dict = field(default_factory=dict)
profile: bool = False
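
For illustration, the two configs might be constructed as follows; the selector string and the torch.compile kwargs are illustrative values, not defaults:

```python
from nemo.automodel.compiler import ThunderConfig, TorchCompileConfig

# Compile only the submodules matching 'model.layers.*' (e.g. one trunk of a
# multi-trunk model), forwarding a mode kwarg to torch.compile.
torch_cfg = TorchCompileConfig(
    module_selector='model.layers.*',
    kwargs={'mode': 'reduce-overhead'},
)

# Compile the root module with Thunder and enable its NVTX profile transform.
thunder_cfg = ThunderConfig(profile=True)
```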
79 changes: 79 additions & 0 deletions nemo/automodel/compiler/module_compiler.py
@@ -0,0 +1,79 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


from nemo.automodel.compiler.configs import ThunderConfig, TorchCompileConfig
from nemo.automodel.compiler.utils import extract_module_attr_name, get_modules_from_selector


def compile_module(config, module):
"""Jit-compiles an nn.Module

Args:
config (JitConfig): jit config
module (nn.Module): the module to be compiled

Returns:
nn.Module: the (potentially) compiled module
"""
if isinstance(config, TorchCompileConfig):
module.compile(**(config.kwargs or {}))
elif isinstance(config, ThunderConfig):
import thunder
import thunder.dynamo
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

        # With this setting, Dynamo graphs inline all the modules (so the Dynamo FX graph
        # consists of `call_function` nodes only, with no `call_module` nodes).
        # This is the default setting in PyTorch 2.5 onwards
        # (see https://github.com/pytorch/pytorch/pull/131275).
torch._dynamo.config.inline_inbuilt_nn_modules = True

xforms: list = [NvtxProfileTransform()] if config.profile else []
module.compile(backend=thunder.dynamo.ThunderCompiler(transforms=xforms))
else:
raise ValueError("Expected config to be TorchCompileConfig or ThunderConfig")


def compile_module_from_config(config, module) -> None:
"""Jit-compiles the model at the start of the epoch.
While other events such as on_train_start are more suitable, we use on_train_epoch_start
since that is what is used in peft (we want to jit after adding the adapters).

Args:
module (nn.Module): the nn.Module to compile.
"""
if config is None:
return
if not isinstance(config, (TorchCompileConfig, ThunderConfig)):
return

    attr_name = extract_module_attr_name(module)
    # extract_module_attr_name returns the module itself when given a bare nn.Module.
    model = getattr(module, attr_name) if isinstance(attr_name, str) else attr_name

    if getattr(module, '_compiled', False):
return

    # TODO(@akoumparouli): concatenate all expressions (via the regex OR-operator) and trigger
    # the compile if any one matches, instead of iterating over all selectors in O(N^2).
    compiled = False
    # Use a distinct loop variable so the outer `module` is not shadowed; otherwise the
    # `_compiled` flag below would be set on the last submodule instead of the holder.
    for submodule in get_modules_from_selector(model, config.module_selector):
        compile_module(config, submodule)
        compiled = True

setattr(module, '_compiled', compiled)
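
End to end, the config path can be exercised on a plain module holder. A sketch with a toy model (TinyNet is hypothetical, not part of this PR):

```python
import torch
import torch.nn as nn

from nemo.automodel.compiler import TorchCompileConfig, compile_module_from_config


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # extract_module_attr_name resolves the held module via this .model attribute.
        self.model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.model(x)


net = TinyNet()
# An empty selector compiles the root of the held module; a second call is a
# no-op thanks to the _compiled flag.
compile_module_from_config(TorchCompileConfig(), net)
out = net(torch.randn(4, 8))
```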
95 changes: 95 additions & 0 deletions nemo/automodel/compiler/utils.py
@@ -0,0 +1,95 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List

import torch
import torch.nn as nn


def extract_module_attr_name(pl_module: "pl.LightningModule") -> str:
"""Extracts the held nn.Module from a pl.LightningModule, will try "module", "model", or fail.

Args:
pl_module (pl.LightningModule): the LightningModule used in training.

Raises:
ValueError: if the pl_module has neither a .mdoel or .module

Returns:
str: the attr-name of the nn.Module
"""
if hasattr(pl_module, 'module'):
return 'module'
elif hasattr(pl_module, 'model'):
return 'model'
elif isinstance(pl_module, nn.Module):
return pl_module
else:
raise ValueError("Expected lightning_module to have a .model or .module attr.")


def listify(x):
"""Wraps input in a list, if not already a list.

Args:
x (Anything): the input, can be anything.

Returns:
Anything | list(Anything): Anything (if it's already a list) o/w list(Anything)
"""
if not isinstance(x, list):
return [x]
return x


def get_modules_from_selector(model, module_selector):
"""Iterator over model's modules whose FQN match the module_selector.

Args:
model (nn.Module): the model to iterate over.
module_selector (str): module selector, if empty or '*' will return the whole model. If
there's an asterisk in the name will match it as a regexp.

Raises:
AttributeError: if the user provides an invalid selector.
AttributeError: if user's selector selects a non-nn.Module attribute.

Yields:
Iterator(nn.Module): iterator over modules whose FQN matches module_selector
"""
if module_selector is None or module_selector == '' or module_selector == '*':
yield model
return

assert isinstance(module_selector, str), module_selector
atoms: List[str] = module_selector.split('.')
tmp = model

    for item in atoms:
if '*' in item:
# handle wildcard selector
# TODO(@akoumparouli): support more complex selectors e.g. net_b.*.net_c.*.conv
for name, module in tmp.named_children():
if re.match(item.replace('*', '.*'), name):
yield module
return

        if not hasattr(tmp, item):
            raise AttributeError(f"{tmp._get_name()} has no attribute `{item}`")
        tmp = getattr(tmp, item)

        if not isinstance(tmp, torch.nn.Module):
            raise AttributeError(f"`{item}` is not an nn.Module")

yield tmp
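
The selector semantics can be sanity-checked on a toy two-trunk module (hypothetical names, for illustration):

```python
import torch.nn as nn

from nemo.automodel.compiler import get_modules_from_selector


class TwoTrunk(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = nn.Linear(4, 4)
        self.text = nn.Linear(4, 4)


model = TwoTrunk()

# An empty selector (or '*') yields the whole model.
assert next(get_modules_from_selector(model, '*')) is model

# An atom containing '*' is matched as a regexp against child names,
# so 'vis*' selects only the 'vision' trunk.
assert list(get_modules_from_selector(model, 'vis*')) == [model.vision]
```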
10 changes: 1 addition & 9 deletions nemo/collections/llm/api.py
@@ -49,7 +49,7 @@
)
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.pytorch.callbacks import PEFT, JitTransform, ModelTransform
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

@@ -1079,14 +1079,6 @@ def _setup(
trainer.callbacks.append(model_transform)
else:
trainer.callbacks.append(ModelTransform())
# Move jit callback at the end ensure it's applied on top of any model transformations (peft)
jit_cb = None
for i, cb in enumerate(trainer.callbacks):
if isinstance(cb, JitTransform):
assert jit_cb is None
jit_cb = trainer.callbacks.pop(i)
if jit_cb is not None:
trainer.callbacks.append(jit_cb)
return app_state


3 changes: 0 additions & 3 deletions nemo/lightning/pytorch/callbacks/__init__.py
@@ -15,7 +15,6 @@
from nemo.lightning.pytorch.callbacks.ddp_parity_checker import DdpParityChecker
from nemo.lightning.pytorch.callbacks.debugging import ParameterDebugger
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.pytorch.callbacks.jit_transform import JitConfig, JitTransform
from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback
from nemo.lightning.pytorch.callbacks.model_callback import ModelCallback
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
@@ -39,6 +38,4 @@
"GarbageCollectionCallback",
"ParameterDebugger",
"ModelCallback",
"JitTransform",
"JitConfig",
]
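
Callers that imported the removed symbols can migrate to the library API. A hedged before/after sketch (the mapping reflects this PR's intent; JitConfig roughly corresponds to the new config dataclasses):

```python
# Before (removed by this PR):
# from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform

# After: a sketch using the library API introduced above.
from nemo.automodel.compiler import TorchCompileConfig, compile_module_from_config
```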