slight code reorg and bug correction for cross_compile#3472
Conversation
| # insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows | ||
| trt_node = gm.graph.call_function( | ||
| torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, | ||
| (trt_module_node.args, *engine_info), |
There was a problem hiding this comment.
Do we still need to unpack this list?
There was a problem hiding this comment.
We would still need to unpack the list. Else while loading in windows it shows
File "C:\Users\abose\Documents\work\TensorRT\torchTRT\Lib\site-packages\torch\_export\serde\serialize.py", line 2258, in deserialize_inputs
args.append(actual_args[schema_arg.name])
~~~~~~~~~~~^^^^^^^^^^^^^^^^^
KeyError: 'name'
py/torch_tensorrt/runtime/_utils.py
Outdated
| serialized_hardware_compatible: str, | ||
| serialized_metadata: str, | ||
| serialized_target_platform: str, | ||
| serialized_require_output_allocator: str, |
There was a problem hiding this comment.
Move this placeholder op to runtime/meta_ops
| getitem_nodes = trt_node.users | ||
| for idx, getitem_node in enumerate(getitem_nodes): | ||
| getitem_node.meta["val"] = trt_node.meta["val"][idx] | ||
| no_op_placeholder_node.replace_all_uses_with(trt_node) |
There was a problem hiding this comment.
Can you add a multi output testcase to the cross compile tests?
| getitem_nodes = trt_node.users | ||
| for idx, getitem_node in enumerate(getitem_nodes): | ||
| getitem_node.meta["val"] = trt_node.meta["val"][idx] | ||
|
|
There was a problem hiding this comment.
@narendasan this is the part which should address the bug
| def forward(self, a, b): | ||
| return torch.add(a, b) | ||
|
|
||
| print("here") |
3f8ab4c to
2934660
Compare
|
@bowang007 on linux converter tests I see- Would you know what is going wrong? |
Hi @apbose , When you do the cross-compile, what is the sm version that you are compiling into? |
|
Hmm @bowang007 are you suggesting the above wrt to the linux tests or the windows test? The error seems to be coming specifically in pytorch/TensorRT/tree/main/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py tests in linux |
Hi @apbose , |
@apbose can we do something like turning off these tests for other platforms for now? |
2934660 to
f8f0f55
Compare
Addresses the following for the cross_compile_for_windows feature-
cross_compile_flagwithcross_compile_module