But when I call the model to get its output with _ = optimized_model(**inputs), it raises this error:
INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Device not specified, using Torch default current device - cuda:0. If this is incorrect, please specify an input device, via the device keyword.
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')
WARNING:torch_tensorrt.dynamo.backend.backends:TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead.
Traceback (most recent call last):
File "/home/fuyx/.local/lib/python3.8/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 90, in _pretraced_backend
gm = aot_export_joint_simple(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1235, in aot_export_joint_simple
fx_g, metadata, in_spec, out_spec = _aot_export_function(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1350, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 95, in aot_dispatch_export
graph, _, _ = aot_dispatch_base_graph(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 138, in aot_dispatch_base_graph
fw_module = _create_graph(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 46, in _create_graph
fx_g = make_fx(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1421, in wrapped
return make_fx_tracer.trace(f, *args)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1367, in trace
return self._trace_inner(f, *args)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1354, in _trace_inner
t = dispatch_trace(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_compile.py", line 31, in inner
return disable_fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 642, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 822, in trace
(self.create_arg(fn(*args)),),
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 660, in wrapped
out = f(*tensors)
File "<string>", line 1, in <lambda>
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 388, in _functionalized_f_helper
f_outs = fn(*f_args)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 72, in inner_fn
outs = fn(*args)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 178, in flat_fn
tree_out = fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/graph_module.py", line 316, in __call__
raise e
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/graph_module.py", line 303, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 569, in call_module
return forward(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.1", line 739, in forward
masked_fill_ = mask.masked_fill_(lt, 0); lt = None
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 705, in __torch_function__
return func(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_device.py", line 79, in __torch_function__
return func(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/functional_tensor.py", line 468, in __torch_dispatch__
outs_unwrapped = func._op_dk(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 755, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 790, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 329, in proxy_call
r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 1441, in maybe_handle_decomp
return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_prims_common/wrappers.py", line 266, in _fn
result = fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_refs/__init__.py", line 5600, in masked_fill
r = torch.where(mask, value, a) # type: ignore[arg-type]
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 755, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 790, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 340, in proxy_call
r = func.decompose(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_ops.py", line 704, in decompose
return self._op_dk(dk, *args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 755, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 790, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py", line 467, in proxy_call
out = func(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_ops.py", line 667, in __call__
return self_._op(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1765, in _dispatch_impl
self.wrap_meta_outputs_with_default_device_logic(
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1875, in wrap_meta_outputs_with_default_device_logic
return tree_map(wrap, r)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_pytree.py", line 948, in tree_map
return treespec.unflatten(map(func, *flat_args))
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/utils/_pytree.py", line 787, in unflatten
leaves = list(leaves)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1853, in wrap
) = FakeTensor._find_common_device(func, flat_args)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 775, in _find_common_device
merge_devices(arg)
File "/home/fuyx/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 770, in merge_devices
raise RuntimeError(
RuntimeError: Unhandled FakeTensor Device Propagation for aten.where.self, found two different devices cuda:0, cpu
How should I solve this problem?
To Reproduce
Steps to reproduce the behavior:
1. Compile the model as mentioned above.
2. Process the input.
3. Invoke the model.
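Since the error reports two devices (cuda:0 and cpu), one thing worth checking between steps 2 and 3 is whether every entry of inputs lives on the same device. The helper below is only a diagnostic sketch: the names group_by_device and assert_single_device are hypothetical, it assumes inputs is a dict whose tensor-like values expose a .device attribute, and it cannot detect a tensor that the model itself creates on the CPU inside forward, which is another possible cause that the traceback alone does not rule out.

```python
from collections import defaultdict


def group_by_device(inputs):
    """Map each device (as a string) to the names of the inputs on it.

    Hypothetical helper: assumes `inputs` is a dict whose tensor-like
    values expose a `.device` attribute, as torch tensors do.
    """
    groups = defaultdict(list)
    for name, value in inputs.items():
        device = getattr(value, "device", None)
        if device is not None:  # skip plain Python values (ints, strings, ...)
            groups[str(device)].append(name)
    return dict(groups)


def assert_single_device(inputs):
    """Raise ValueError if the inputs span more than one device.

    Returns the single device string, or None if no entry has a device.
    """
    groups = group_by_device(inputs)
    if len(groups) > 1:
        raise ValueError(f"inputs span several devices: {groups}")
    return next(iter(groups), None)
```

If the check fails, moving each tensor to the GPU before the call, for example with value.to("cuda:0") on every tensor entry, may be enough. If it passes, the CPU tensor is probably created somewhere inside the model's forward.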
Expected behavior
The forward pass finishes successfully.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
Torch-TensorRT Version (e.g. 1.0.0): 2.4.0+cu118
PyTorch Version (e.g. 1.0): 2.4.1+cu118
CPU Architecture: x86_64
OS (e.g., Linux): Ubuntu 20.04.6 LTS
How you installed PyTorch (conda, pip, libtorch, source): pip
Build command you used (if compiling from source):
Are you using local sources or building from archives:
Python version: 3.8.10
CUDA version: 11.8
GPU models and configuration: RTX 2080 Ti, Driver Version: 535.161.08
Any other relevant information:
Additional context
Bug Description
When I use torch.compile with TensorRT as the backend for a DFNs model and call the processor function to prepare the inputs, invoking the model with
_ = optimized_model(**inputs)
raises the error shown above. How should I solve this problem?