-
Notifications
You must be signed in to change notification settings - Fork 3.3k
[automodel] Move jittransform to library #12880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 2 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
b15a451
init deprecation of jittransform
akoumpa 6283c90
Apply isort and black reformatting
akoumpa b7c2caa
Potential fix for code scanning alert no. 13783: Use of the return va…
akoumpa 3a0f32d
Potential fix for code scanning alert no. 13786: Unused import
akoumpa d47bf55
Potential fix for code scanning alert no. 13793: Unused import
akoumpa cd2827f
update doc
akoumpa b70195f
introduce compiler_config
akoumpa 5ab58e1
test moved to tests/automodel/compiler/test_compiler.py
akoumpa adc60b9
fix
akoumpa ec7cbdb
fix
akoumpa de15545
Apply isort and black reformatting
akoumpa File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # VISION | ||
| from nemo.automodel.compiler.configs import ThunderConfig, TorchCompileConfig | ||
| from nemo.automodel.compiler.module_compiler import compile_module, compile_module_from_config | ||
| from nemo.automodel.compiler.utils import extract_module_attr_name, get_modules_from_selector, listify | ||
|
|
||
| __all__ = [ | ||
| "TorchCompileConfig", | ||
| "ThunderConfig", | ||
| "compile_module", | ||
| "compile_module_from_config", | ||
| "extract_module_attr_name", | ||
| "listify", | ||
| "get_modules_from_selector", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from dataclasses import dataclass, field | ||
|
|
||
|
|
||
| @dataclass | ||
| class TorchCompileConfig: | ||
| """Config for torch.compile | ||
| Options: | ||
| - module_selector (str): reg-exp to match modules to compile, useful for multi-trunk | ||
| models where you want to apply it on one of them only. If empty will apply transform to root | ||
| module. | ||
| - apply_pre_wrap: if True will compile before wrapping with DDP/FSDP2 & vice-versa. | ||
| - kwargs (dict): kwargs to pass to torch.compile. | ||
| """ | ||
|
|
||
| module_selector: str = '' | ||
| apply_pre_wrap: bool = True | ||
| kwargs: dict = field(default_factory=dict) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ThunderConfig: | ||
| """Config for Thunder | ||
| Options: | ||
| - module_selector (str): reg-exp to match modules to compile, useful for multi-trunk | ||
| models where you want to apply it on one of them only. If empty will apply transform to root | ||
| module. | ||
| - apply_pre_wrap: if True will compile before wrapping with DDP/FSDP2 & vice-versa. | ||
| - kwargs (dict): kwargs to pass to thunder, currently unused. | ||
| - profile (bool): toggle for thunder's profiler. | ||
| """ | ||
|
|
||
| module_selector: str = '' | ||
| apply_pre_wrap: bool = True | ||
| kwargs: dict = field(default_factory=dict) | ||
| profile: bool = False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import re | ||
|
||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
||
|
|
||
| from nemo.automodel.compiler.configs import ThunderConfig, TorchCompileConfig | ||
| from nemo.automodel.compiler.utils import extract_module_attr_name, get_modules_from_selector | ||
|
|
||
|
|
||
| def compile_module(config, module): | ||
| """Jit-compiles an nn.Module | ||
|
|
||
| Args: | ||
| config (JitConfig): jit config | ||
| module (nn.Module): the module to be compiled | ||
|
|
||
| Returns: | ||
| nn.Module: the (potentially) compiled module | ||
| """ | ||
| if isinstance(config, TorchCompileConfig): | ||
| module.compile(**(config.kwargs or {})) | ||
| elif isinstance(config, ThunderConfig): | ||
| import thunder | ||
| import thunder.dynamo | ||
| from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform | ||
|
|
||
| # With this setting, Dynamo Graphs inline all the modules (so Dynamo FXGraph just | ||
| # consists of `call_function` nodes only and no `call_module` node. | ||
| # This is the default setting in PyTorch 2.5 onwards | ||
| # (see https://github.com/pytorch/pytorch/pull/131275) | ||
| torch._dynamo.config.inline_inbuilt_nn_modules = True | ||
|
|
||
| xforms: list = [NvtxProfileTransform()] if config.profile else [] | ||
| module.compile(backend=thunder.dynamo.ThunderCompiler(transforms=xforms)) | ||
| else: | ||
| raise ValueError("Expected config to be TorchCompileConfig or ThunderConfig") | ||
|
|
||
|
|
||
| def compile_module_from_config(config, module) -> None: | ||
| """Jit-compiles the model at the start of the epoch. | ||
| While other events such as on_train_start are more suitable, we use on_train_epoch_start | ||
| since that is what is used in peft (we want to jit after adding the adapters). | ||
|
|
||
| Args: | ||
| module (nn.Module): the nn.Module to compile. | ||
| """ | ||
| if config is None: | ||
| return | ||
| if not isinstance(config, (TorchCompileConfig, ThunderConfig)): | ||
| return | ||
|
|
||
| attr_name = extract_module_attr_name(module) | ||
| model = getattr(module, attr_name) | ||
|
|
||
| if getattr(module, '_compiled', False) == True: | ||
| return | ||
|
|
||
| # TODO(@akoumparouli): you want to concatenate (via regex OR-operator) all expressions | ||
| # and trigger the compile if anyone matches, instead of iterating over all O(N^2). | ||
| compiled = False | ||
| for module in get_modules_from_selector(model, config.module_selector): | ||
| compile_module(config, module) | ||
| compiled = True | ||
|
|
||
| setattr(module, '_compiled', compiled) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import re | ||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
|
|
||
| def extract_module_attr_name(pl_module: "pl.LightningModule") -> str: | ||
| """Extracts the held nn.Module from a pl.LightningModule, will try "module", "model", or fail. | ||
|
|
||
| Args: | ||
| pl_module (pl.LightningModule): the LightningModule used in training. | ||
|
|
||
| Raises: | ||
| ValueError: if the pl_module has neither a .mdoel or .module | ||
|
|
||
| Returns: | ||
| str: the attr-name of the nn.Module | ||
| """ | ||
| if hasattr(pl_module, 'module'): | ||
| return 'module' | ||
| elif hasattr(pl_module, 'model'): | ||
| return 'model' | ||
| elif isinstance(pl_module, nn.Module): | ||
| return pl_module | ||
| else: | ||
| raise ValueError("Expected lightning_module to have a .model or .module attr.") | ||
|
|
||
|
|
||
| def listify(x): | ||
| """Wraps input in a list, if not already a list. | ||
|
|
||
| Args: | ||
| x (Anything): the input, can be anything. | ||
|
|
||
| Returns: | ||
| Anything | list(Anything): Anything (if it's already a list) o/w list(Anything) | ||
| """ | ||
| if not isinstance(x, list): | ||
| return [x] | ||
| return x | ||
|
|
||
|
|
||
| def get_modules_from_selector(model, module_selector): | ||
| """Iterator over model's modules whose FQN match the module_selector. | ||
|
|
||
| Args: | ||
| model (nn.Module): the model to iterate over. | ||
| module_selector (str): module selector, if empty or '*' will return the whole model. If | ||
| there's an asterisk in the name will match it as a regexp. | ||
|
|
||
| Raises: | ||
| AttributeError: if the user provides an invalid selector. | ||
| AttributeError: if user's selector selects a non-nn.Module attribute. | ||
|
|
||
| Yields: | ||
| Iterator(nn.Module): iterator over modules whose FQN matches module_selector | ||
| """ | ||
| if module_selector is None or module_selector == '' or module_selector == '*': | ||
| yield model | ||
| return | ||
|
|
||
| assert isinstance(module_selector, str), module_selector | ||
| atoms: List[str] = module_selector.split('.') | ||
| tmp = model | ||
|
|
||
| for i, item in enumerate(atoms): | ||
| if '*' in item: | ||
| # handle wildcard selector | ||
| # TODO(@akoumparouli): support more complex selectors e.g. net_b.*.net_c.*.conv | ||
| for name, module in tmp.named_children(): | ||
| if re.match(item.replace('*', '.*'), name): | ||
| yield module | ||
| return | ||
|
|
||
| if not hasattr(tmp, item): | ||
| raise AttributeError(tmp._get_name() + " has no " "attribute `" + item + "`") | ||
| tmp = getattr(tmp, item) | ||
|
|
||
| if not isinstance(tmp, torch.nn.Module): | ||
| raise AttributeError("`" + item + "` is not " "an nn.Module") | ||
|
|
||
| yield tmp |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.