
🐛 [Bug] LSTM error using torch tensorrt #3427

Open
zmy1116 opened this issue Mar 4, 2025 · 0 comments
Labels
bug Something isn't working

Comments


zmy1116 commented Mar 4, 2025

Bug Description

Error when compiling an LSTM model using Torch-TensorRT.

To Reproduce

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt
os.environ['CUDA_VISIBLE_DEVICES'] = '3'


model = nn.LSTM(60, 128, num_layers=4, batch_first=True, dropout=0.5, bidirectional=True)
model = model.eval().cuda()

trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 2, 60))],
    enabled_precisions={torch_tensorrt.dtype.f16},
)

Errors

[WARNING  | root               ]: Given dtype that does not have direct mapping to torch (dtype.unknown), defaulting to torch.float
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 12
      9 model = nn.LSTM(60, 128, num_layers=4, batch_first=True, dropout=0.5, bidirectional=True)
     10 model = model.eval().cuda()
---> 12 trt_model = torch_tensorrt.compile(model,
     13     inputs= [torch_tensorrt.Input((1,2,60)) ])

File /usr/local/lib/python3.12/dist-packages/torch_tensorrt/_compile.py:286, in compile(module, ir, inputs, arg_inputs, kwarg_inputs, enabled_precisions, **kwargs)
    283 torchtrt_arg_inputs = prepare_inputs(arg_inputs)
    284 torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
--> 286 exp_program = dynamo_trace(
    287     module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
    288 )
    289 trt_graph_module = dynamo_compile(
    290     exp_program,
    291     arg_inputs=torchtrt_arg_inputs,
    292     enabled_precisions=enabled_precisions_set,
    293     **kwargs,
    294 )
    295 return trt_graph_module

File /usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_tracer.py:83, in trace(mod, inputs, arg_inputs, kwarg_inputs, **kwargs)
     81 dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
     82 dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
---> 83 exp_program = export(
     84     mod,
     85     tuple(torch_arg_inputs),
     86     kwargs=torch_kwarg_inputs,
     87     dynamic_shapes=dynamic_shapes,
     88 )
     90 return exp_program

File /usr/local/lib/python3.12/dist-packages/torch/export/__init__.py:368, in export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)
    362 if isinstance(mod, torch.jit.ScriptModule):
    363     raise ValueError(
    364         "Exporting a ScriptModule is not supported. "
    365         "Maybe try converting your ScriptModule to an ExportedProgram "
    366         "using `TS2EPConverter(mod, args, kwargs).convert()` instead."
    367     )
--> 368 return _export(
    369     mod,
    370     args,
    371     kwargs,
    372     dynamic_shapes,
    373     strict=strict,
    374     preserve_module_call_signature=preserve_module_call_signature,
    375     pre_dispatch=True,
    376 )

File /usr/local/lib/python3.12/dist-packages/torch/export/_trace.py:1038, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1031     else:
   1032         log_export_usage(
   1033             event="export.error.unclassified",
   1034             type=error_type,
   1035             message=str(e),
   1036             flags=_EXPORT_FLAGS,
   1037         )
-> 1038     raise e
   1039 finally:
   1040     _EXPORT_FLAGS = None

File /usr/local/lib/python3.12/dist-packages/torch/export/_trace.py:1011, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1009 try:
   1010     start = time.time()
-> 1011     ep = fn(*args, **kwargs)
   1012     end = time.time()
   1013     log_export_usage(
   1014         event="export.time",
   1015         metrics=end - start,
   1016         flags=_EXPORT_FLAGS,
   1017         **get_ep_stats(ep),
   1018     )

File /usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py:128, in _disable_prexisiting_fake_mode.<locals>.wrapper(*args, **kwargs)
    125 @functools.wraps(fn)
    126 def wrapper(*args, **kwargs):
    127     with unset_fake_temporarily():
--> 128         return fn(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/export/_trace.py:2057, in _export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)
   2049 # NOTE Export training IR rollout
   2050 # Old export calls export._trace(pre_dispatch=True)
   2051 # and there are still lot of internal/OSS callsites that
   (...)
   2054 # export_training_ir_rollout_check returns True in OSS
   2055 # while internally it returns False UNLESS otherwise specified.
   2056 if pre_dispatch and export_training_ir_rollout_check():
-> 2057     return _export_for_training(
   2058         mod,
   2059         args,
   2060         kwargs,
   2061         dynamic_shapes,
   2062         strict=strict,
   2063         preserve_module_call_signature=preserve_module_call_signature,
   2064     )
   2066 (
   2067     args,
   2068     kwargs,
   2069     original_in_spec,
   2070     dynamic_shapes,
   2071 ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)
   2073 original_state_dict = _get_original_state_dict(mod)

File /usr/local/lib/python3.12/dist-packages/torch/export/_trace.py:1038, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1031     else:
   1032         log_export_usage(
   1033             event="export.error.unclassified",
   1034             type=error_type,
   1035             message=str(e),
   1036             flags=_EXPORT_FLAGS,
   1037         )
-> 1038     raise e
   1039 finally:
   1040     _EXPORT_FLAGS = None

File /usr/local/lib/python3.12/dist-packages/torch/export/_trace.py:1011, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1009 try:
   1010     start = time.time()
-> 1011     ep = fn(*args, **kwargs)
   1012     end = time.time()
   1013     log_export_usage(
   1014         event="export.time",
   1015         metrics=end - start,
   1016         flags=_EXPORT_FLAGS,
   1017         **get_ep_stats(ep),
   1018     )

File /usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py:128, in _disable_prexisiting_fake_mode.<locals>.wrapper(*args, **kwargs)
    125 @functools.wraps(fn)
    126 def wrapper(*args, **kwargs):
    127     with unset_fake_temporarily():
--> 128         return fn(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/export/_trace.py:1921, in _export_for_training(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)
   1908 original_state_dict = _get_original_state_dict(mod)
   1910 export_func = (
   1911     functools.partial(
   1912         _strict_export_lower_to_aten_ir,
   (...)
   1919     )
   1920 )
-> 1921 export_artifact = export_func(  # type: ignore[operator]
   1922     mod=mod,
   1923     args=args,
   1924     kwargs=kwargs,
   1925     dynamic_shapes=dynamic_shapes,
   1926     preserve_module_call_signature=preserve_module_call_signature,
   1927     pre_dispatch=False,
   1928     original_state_dict=original_state_dict,
   1929     orig_in_spec=orig_in_spec,
   1930     allow_complex_guards_as_runtime_asserts=False,
   1931     _is_torch_jit_trace=False,
   1932 )
   1934 export_graph_signature = export_artifact.aten.sig
   1936 forward_arg_names = _get_forward_arg_names(mod, args, kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/export/_trace.py:1290, in _strict_export_lower_to_aten_ir(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, lower_to_aten_callback)
   1277 def _strict_export_lower_to_aten_ir(
   1278     mod: torch.nn.Module,
   1279     args: Tuple[Any, ...],
   (...)
   1288     lower_to_aten_callback: Callable,
   1289 ) -> ExportArtifact:
-> 1290     gm_torch_level = _export_to_torch_ir(
   1291         mod,
   1292         args,
   1293         kwargs,
   1294         dynamic_shapes,
   1295         preserve_module_call_signature=preserve_module_call_signature,
   1296         restore_fqn=False,  # don't need to restore because we will do it later
   1297         allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
   1298         _log_export_usage=False,
   1299     )
   1301     # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
   1302     (
   1303         fake_args,
   1304         fake_kwargs,
   1305         dynamo_fake_mode,
   1306     ) = _extract_fake_inputs(gm_torch_level, args, kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/export/_trace.py:674, in _export_to_torch_ir(f, args, kwargs, dynamic_shapes, preserve_module_call_signature, disable_constraint_solver, allow_complex_guards_as_runtime_asserts, restore_fqn, _log_export_usage, same_signature)
    670         ctx = _wrap_submodules(  # type: ignore[assignment]
    671             f, preserve_module_call_signature, module_call_specs
    672         )
    673     with ctx, _ignore_backend_decomps():
--> 674         gm_torch_level, _ = torch._dynamo.export(
    675             f,
    676             dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
    677             assume_static_by_default=True,
    678             tracing_mode="symbolic",
    679             disable_constraint_solver=disable_constraint_solver,
    680             # currently the following 2 flags are tied together for export purposes,
    681             # but untangle for sake of dynamo export api
    682             prefer_deferred_runtime_asserts_over_guards=True,
    683             allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
    684             _log_export_usage=_log_export_usage,
    685             same_signature=same_signature,
    686         )(
    687             *args,
    688             **kwargs,
    689         )
    690 except (ConstraintViolationError, ValueRangeError) as e:
    691     raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904

File /usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py:1583, in export.<locals>.inner(*args, **kwargs)
   1581 # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
   1582 try:
-> 1583     result_traced = opt_f(*args, **kwargs)
   1584 except ConstraintViolationError as e:
   1585     constraint_violation_error = e

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py:576, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    573 _maybe_set_eval_frame(_callback_from_stance(callback))
    575 try:
--> 576     return fn(*args, **kwargs)
    577 except ShortenTraceback as e:
    578     # Failures in the backend likely don't have useful
    579     # data in the TorchDynamo frames, so we strip them out.
    580     raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1

File /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:46, in wrap_inline.<locals>.inner(*args, **kwargs)
     41 def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     42     """
     43     Create an extra frame around fn that is not in skipfiles.
     44     """
---> 46     @functools.wraps(fn)
     47     def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
     48         return fn(*args, **kwargs)
     50     return inner

File /usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py:755, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    753 _maybe_set_eval_frame(_callback_from_stance(self.callback))
    754 try:
--> 755     return fn(*args, **kwargs)
    756 finally:
    757     set_eval_frame(None)

File /usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py:1546, in export.<locals>.inner.<locals>.dynamo_normalization_capturing_compiler.<locals>.result_capturing_wrapper(*graph_inputs)
   1539         fake_params_buffers[name] = ambient_fake_mode.from_tensor(
   1540             value, static_shapes=True
   1541         )
   1543     fake_graph_inputs = pytree.tree_map(
   1544         ambient_fake_mode.from_tensor, graph_inputs
   1545     )
-> 1546     graph_captured_result = torch.func.functional_call(
   1547         graph, fake_params_buffers, fake_graph_inputs
   1548     )
   1550 return graph_captured_result

File /usr/local/lib/python3.12/dist-packages/torch/_functorch/functional_call.py:147, in functional_call(module, parameter_and_buffer_dicts, args, kwargs, tie_weights, strict)
    141 else:
    142     raise ValueError(
    143         f"Expected parameter_and_buffer_dicts to be a dict, or a list[/tuple](http://34.72.49.31:8889/tuple) of dicts, "
    144         f"but got {type(parameter_and_buffer_dicts)}"
    145     )
--> 147 return nn.utils.stateless._functional_call(
    148     module,
    149     parameters_and_buffers,
    150     args,
    151     kwargs,
    152     tie_weights=tie_weights,
    153     strict=strict,
    154 )

File /usr/local/lib/python3.12/dist-packages/torch/nn/utils/stateless.py:282, in _functional_call(module, parameters_and_buffers, args, kwargs, tie_weights, strict)
    278     args = (args,)
    279 with _reparametrize_module(
    280     module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
    281 ):
--> 282     return module(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py:126, in _LazyGraphModule._lazy_forward(self, *args, **kwargs)
    121 assert not self._needs_recompile()
    123 # call `__call__` rather than 'forward' since recompilation may
    124 # install a wrapper for `__call__` to provide a customized error
    125 # message.
--> 126 return self(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py:822, in GraphModule.recompile.<locals>.call_wrapped(self, *args, **kwargs)
    821 def call_wrapped(self, *args, **kwargs):
--> 822     return self._wrapped_call(self, *args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py:400, in _WrappedCall.__call__(self, obj, *args, **kwargs)
    398     raise e.with_traceback(None)  # noqa: B904
    399 else:
--> 400     raise e

File /usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py:387, in _WrappedCall.__call__(self, obj, *args, **kwargs)
    385         return self.cls_call(obj, *args, **kwargs)
    386     else:
--> 387         return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
    388 except Exception as e:
    389     assert e.__traceback__

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File <eval_with_key>.5:6, in forward(self, L_args_0_)
      4 def forward(self, L_args_0_ : torch.Tensor):
      5     l_args_0_ = L_args_0_
----> 6     fn = self.fn(l_args_0_);  l_args_0_ = None
      7     getitem = fn[0]
      8     getitem_1 = fn[1];  fn = None

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/rnn.py:1043, in LSTM.forward(self, input, hx)
   1042 def forward(self, input, hx=None):  # noqa: F811
-> 1043     self._update_flat_weights()
   1045     orig_input = input
   1046     # xxx: isinstance check needs to be in conditional for TorchScript to compile

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/rnn.py:392, in RNNBase._update_flat_weights(self)
    390 if not torch.jit.is_scripting():
    391     if self._weights_have_changed():
--> 392         self._init_flat_weights()

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/rnn.py:215, in RNNBase._init_flat_weights(self)
    208 self._flat_weights = [
    209     getattr(self, wn) if hasattr(self, wn) else None
    210     for wn in self._flat_weights_names
    211 ]
    212 self._flat_weight_refs = [
    213     weakref.ref(w) if w is not None else None for w in self._flat_weights
    214 ]
--> 215 self.flatten_parameters()

File /usr/local/lib/python3.12/dist-packages/torch/nn/modules/rnn.py:256, in RNNBase.flatten_parameters(self)
    249         return
    251 # If any parameters alias, we fall back to the slower, copying code path. This is
    252 # a sufficient check, because overlapping parameter buffers that don't completely
    253 # alias would break the assumptions of the uniqueness check in
    254 # Module.named_parameters().
    255 unique_data_ptrs = {
--> 256     p.data_ptr() for p in self._flat_weights  # type: ignore[union-attr]
    257 }
    258 if len(unique_data_ptrs) != len(self._flat_weights):
    259     return

RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

Expected behavior

I expect no errors to occur, or maybe I did something wrong.

Also, to be honest, the Torch-TensorRT documentation is a bit confusing: there is the Dynamo thing, the TorchScript thing, the FX thing, the FX graph... There is compile, and then sometimes I have to trace the output to save it? Some compile outputs can be traced and saved, others can't? I'm not trying to ramble, but every page under https://pytorch.org/TensorRT has little tweaks here and there: sometimes you pass a backend, other times you change ir, and on another page you start with torch.dynamo...
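
For example, these are the variants I've pieced together from different doc pages (I may have the exact names or defaults wrong, which is exactly my point):

# Dynamo / torch.export front-end (what compile() appears to have used above, judging from the traceback)
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[torch_tensorrt.Input((1, 2, 60))])

# TorchScript front-end
trt_model = torch_tensorrt.compile(model, ir="ts", inputs=[torch_tensorrt.Input((1, 2, 60))])

# torch.compile route with the TensorRT backend
trt_model = torch.compile(model, backend="tensorrt")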

Regardless, I can't compile an LSTM.

Environment

I'm using the NGC Docker image nvcr.io/nvidia/pytorch:25.02-py3.

Running on a Google Cloud VM with an L4 GPU.

zmy1116 added the bug (Something isn't working) label Mar 4, 2025