
Not support torch.compile() ? #145

Open
Vieeo opened this issue Mar 5, 2025 · 1 comment

Vieeo commented Mar 5, 2025

I'm training a sparsified Flux model with torch.compile().

Basic version info:
- Python 3.12.0
- PyTorch 2.5.0
- nvidia-modelopt 0.21.0
- CUDA 12.6

import modelopt.torch.opt as mto

dit = mto.restore(dit, "flux_sparse.pth")
dit = accelerator.prepare(dit)
dit = torch.compile(dit)

The error is as follows:

[rank4]: Traceback (most recent call last):
[rank4]: File "/data/train_sat_flux.py", line 380, in <module>
[rank4]: main()
[rank4]: File "/data/train_sat_flux.py", line 286, in main
[rank4]: model_pred = dit(
[rank4]: ^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank4]: return self._call_impl(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank4]: return forward_call(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank4]: return fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 40, in inner
[rank4]: return fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank4]: return self._call_impl(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank4]: return forward_call(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 850, in forward
[rank4]: args, kwargs = _pre_forward(
[rank4]: ^^^^^^^^^^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 382, in _pre_forward
[rank4]: unshard_fn(state, handle)
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 417, in _pre_forward_unshard
[rank4]: _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 301, in _unshard
[rank4]: handle.unshard()
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1313, in unshard
[rank4]: self._use_unsharded_flat_param(padded_unsharded_flat_param)
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1444, in _use_unsharded_flat_param
[rank4]: self._use_unsharded_views(
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank4]: return func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1928, in _use_unsharded_views
[rank4]: self._setattr_tensor(module, param_name, param_var)
[rank4]: File "/data/modelopt/torch/opt/_hooks.py", line 36, in _safe_setattr_tensor_or_param_with_dm_check
[rank4]: with (
[rank4]: File "/root/miniforge3/envs/py312torch250/lib/python3.12/contextlib.py", line 144, in __exit__
[rank4]: next(self.gen)
[rank4]: File "/data/modelopt/torch/opt/dynamic.py", line 753, in reset_dynamic_attributes
[rank4]: self._register_dynamic_attribute(k, lambda _, v: v)
[rank4]: File "/data/modelopt/torch/opt/dynamic.py", line 429, in _register_dynamic_attribute
[rank4]: def _register_dynamic_attribute(self, name: str, callback: DynamicAttributeCallback):
[rank4]: File "/data/modelopt/torch/opt/dynamic.py", line 470, in torch_dynamo_resume_in__register_dynamic_attribute_at_449
[rank4]: delattr(self, name)
[rank4]: File "/data/modelopt/torch/opt/dynamic.py", line 814, in __delattr__
[rank4]: def __delattr__(self, name: str):
[rank4]: File "/data/modelopt/torch/opt/dynamic.py", line 814, in __delattr__
[rank4]: def __delattr__(self, name: str):
[rank4]: File "/data/modelopt/torch/opt/dynamic.py", line 814, in __delattr__
[rank4]: def __delattr__(self, name: str):
[rank4]: [Previous line repeated 726 more times]
[rank4]: RecursionError: maximum recursion depth exceeded
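The tail of the traceback shows `__delattr__` re-entering itself until the interpreter gives up. A minimal, torch-free sketch of that failure mode (a hypothetical class, not ModelOpt's actual implementation; it only illustrates why a `__delattr__` that routes back into `delattr` instead of reaching `object.__delattr__` recurses forever):

```python
class DynamicModule:
    """Hypothetical stand-in for a module with a dynamic-attribute __delattr__.

    Not ModelOpt's actual code: it only illustrates the failure mode in the
    traceback above, where torch.compile's resume frames appear to route
    __delattr__ back into itself instead of letting it fall through to
    object.__delattr__.
    """

    def __delattr__(self, name):
        # Under normal execution this would eventually call
        # object.__delattr__(self, name); if the frame is transformed so the
        # call re-enters this method instead, the recursion never terminates.
        delattr(self, name)


obj = DynamicModule()
obj.x = 1
try:
    del obj.x  # triggers the self-recursive __delattr__
except RecursionError as e:
    print(type(e).__name__)
```

The fix on the eager path is to terminate the chain with `object.__delattr__(self, name)`; under Dynamo, the resumed frame seems to bypass that and re-dispatches through the overridden method.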

cjluo-nv (Collaborator) commented Apr 3, 2025

Unfortunately, our fake-quant kernel may not support torch.compile. If you need acceleration, we recommend running this with TensorRT or AutoDeploy (https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/auto_deploy/build_and_run_flux.py).
