
🐛 [Bug] cannot convert x.to(torch.uint8) #3247

@braindevices


Bug Description

Conversion reports that _to_copy is not supported when we use something like x.to(torch.uint8):

torch_tensorrt.dynamo.conversion._TRTInterpreter.UnsupportedOperatorException: Conversion of function torch._ops.aten.aten::_to_copy not currently supported!

While executing %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%mul,), kwargs = {dtype: torch.uint8, _itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7f3a1e94bc30>: ((1, 3, 5, 7), torch.float32, False, (105, 35, 7, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3a20b3d3b0>: ((1, 3, 5, 7), torch.float32, False, (105, 35, 7, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3a20aa39b0>: ((1, 3, 5, 7), torch.float32, False, (105, 35, 7, 1), torch.contiguous_format, False, {})}})

However, TensorRT does support this operation: if we export the same model to ONNX and then load the ONNX file into TensorRT, it works (see step 3 below).
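
Until a converter for aten._to_copy lands, a possible workaround is to leave the unsupported op in PyTorch and let TensorRT handle the rest of the graph. A minimal sketch, using dummy_t and xs from the repro below, and assuming the dynamo frontend's torch_executed_ops option accepts aten op names as strings (note this means giving up require_full_compilation):

import torch
import torch_tensorrt

# Workaround sketch (untested here): aten._to_copy stays in PyTorch,
# the rest of the graph is compiled to TensorRT.
trt_model = torch_tensorrt.compile(
    dummy_t().cuda(),
    ir="dynamo",
    inputs=xs,
    enabled_precisions={torch.float32},
    min_block_size=1,
    require_full_compilation=False,
    torch_executed_ops={"torch.ops.aten._to_copy.default"},
)
print(trt_model(*xs).dtype)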

To Reproduce

Steps to reproduce the behavior:

  1. Define a dummy model containing to():
import torch
from torch import nn

class dummy_t(nn.Module):
    def forward(self, x: torch.Tensor):
        # clamp to [0, 1], scale to [0, 255], then cast to uint8;
        # the cast is what lowers to aten._to_copy
        return x.clamp_(0, 1).mul_(255).to(dtype=torch.uint8)

xs = [torch.randn((1, 3, 5, 7)).cuda()]
exported = torch.export.export(
    dummy_t().cuda(),
    args=tuple(xs),
)
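# Sanity check (illustrative addition, not part of the original repro):
# eager PyTorch handles the cast and returns uint8; clone the input
# because forward() mutates it in place.
assert dummy_t().cuda()(xs[0].clone()).dtype == torch.uint8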
  2. Run the Torch-TensorRT conversion; it fails with the error above:
import torch_tensorrt

torch_tensorrt.dynamo.convert_module_to_trt_engine(
    exported,
    assume_dynamic_shape_support=False,
    inputs=tuple(xs),
    use_python_runtime=False,
    enabled_precisions={torch.float32},
    use_fast_partitioner=False,
    debug=True,
    min_block_size=1,
    require_full_compilation=True,
)
  3. Run the ONNX -> TensorRT path instead; it works fine:
from tempfile import NamedTemporaryFile
import onnx
import tensorrt as trt
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
import io
output_names = ['output0']
input_names = ["x"]
with NamedTemporaryFile() as f:
    onnx_program = torch.onnx.export(
        dummy_t().cuda(),
        tuple(xs),
        f.name,
        verbose=False,
        opset_version=20,
        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        # https://github.com/pytorch/pytorch/issues/73843
        input_names=input_names,
        output_names=output_names,
        dynamo=False,
        training=torch.onnx.TrainingMode.EVAL # we can export trainable model!
    )
    model_onnx: onnx.ModelProto
    model_onnx = onnx.load(f.name)
workspace = 10 * 1024**2  # 10 MiB workspace limit
trt_logger = trt.Logger(trt.Logger.INFO)
trt_logger.min_severity = trt.Logger.Severity.VERBOSE

builder = trt.Builder(trt_logger)
config = builder.create_builder_config()
config.set_memory_pool_limit(pool=trt.MemoryPoolType.WORKSPACE, pool_size=workspace)
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, trt_logger)
if not parser.parse(model_onnx.SerializeToString()):
    raise RuntimeError('failed to load ONNX model')

# (optional) inspect the parsed network's inputs and outputs
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]

with builder.build_serialized_network(network, config) as engine, io.BytesIO() as engine_bytes: # type: ignore
    # engine is an IHostMemory buffer holding the serialized engine
    engine_bytes.write(engine)
    engine_bytes.seek(0)
    serialized_trt_engine = engine_bytes.read()

pt_trt = PythonTorchTensorRTModule(
    serialized_trt_engine,
    input_names=input_names,
    output_names=["output0"]
)
print(pt_trt(*xs).dtype)
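
The final print should show torch.uint8. As an illustrative addition, the output dtype can also be checked against eager PyTorch (cloning the input, since forward() mutates it in place):

ref = dummy_t().cuda()(xs[0].clone())
out = pt_trt(xs[0].clone())
assert out.dtype == ref.dtype == torch.uint8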

Expected behavior

x.to(torch.uint8) should be converted to a TensorRT cast (as it is via the ONNX path) instead of raising UnsupportedOperatorException.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT 2.4.0
  • PyTorch Version (e.g. 1.0): 2.4.1
  • CPU Architecture: x86_64
  • OS (e.g., Linux): AlmaLinux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Python version: 3.11
  • CUDA version: 12.3
  • GPU models and configuration: RTX4k

Additional context
