Skip to content

error: 'torch.aten.copy_' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got '!torch.vtensor #4346

@lionsky

Description

@lionsky

How to reproduce?

python test.py

import torch
import torch.nn as nn
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

import torch_mlir
from torch_mlir import fx as fx
from torch_mlir.compiler_utils import OutputType
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.extras.fx_importer import FxImporter

from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    context = torch_mlir.ir.Context()
    torch_d.register_dialect(context)
    gm = make_fx(functionalize(gm), decomposition_table=get_decomposition_table())(*example_inputs)
    fx_importer = FxImporter(context=context, hooks=None)
    fx_importer.import_stateless_graph(gm.graph, func_name="forwards")
    linalg_mlir = fx._module_lowering(
        True,
        False,
        OutputType.LINALG_ON_TENSORS,
        fx_importer.module,
        backend_legal_ops=None,
    )

    return gm.forward

class NormalizeRouting(nn.Module):
    def forward(self, routing_weights: torch.Tensor) -> torch.Tensor:
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        return

x = torch.randn(2, 4)
compiled_fn = torch.compile(NormalizeRouting(), backend=custom_backend)


compiled_fn(x)

The output is

error: 'torch.aten.copy_' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got '!torch.vtensor<[2,4],f32>'

====================
TorchFX IR
"builtin.module"() ({
  "func.func"() <{function_type = (!torch.vtensor<[2,4],f32>) -> (), sym_name = "forwards"}> ({
  ^bb0(%arg0: !torch.vtensor<[2,4],f32>):
    %0 = "torch.constant.int"() <{value = -1 : i64}> : () -> !torch.int
    %1 = "torch.prim.ListConstruct"(%0) : (!torch.int) -> !torch.list<int>
    %2 = "torch.constant.bool"() <{value = true}> : () -> !torch.bool
    %3 = "torch.constant.none"() : () -> !torch.none
    %4 = "torch.aten.sum.dim_IntList"(%arg0, %1, %2, %3) : (!torch.vtensor<[2,4],f32>, !torch.list<int>, !torch.bool, !torch.none) -> !torch.vtensor<[2,1],f32>
    %5 = "torch.aten.div.Tensor"(%arg0, %4) : (!torch.vtensor<[2,4],f32>, !torch.vtensor<[2,1],f32>) -> !torch.vtensor<[2,4],f32>
    %6 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
    %7 = "torch.aten.copy_"(%arg0, %5, %6) : (!torch.vtensor<[2,4],f32>, !torch.vtensor<[2,4],f32>, !torch.bool) -> !torch.vtensor<[2,4],f32>
    "func.return"() : () -> ()
  }) : () -> ()
}) : () -> ()

Traceback (most recent call last):
  File "./workspace/demos/test.py", line 41, in <module>
    compiled_fn(x)
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
    raise BackendCompilerFailed(
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/__init__.py", line 2425, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./workspace/demos/test.py", line 22, in custom_backend
    linalg_mlir = fx._module_lowering(
                  ^^^^^^^^^^^^^^^^^^^^
  File "./workspace/torch-mlir/build/python_packages/torch_mlir/torch_mlir/fx.py", line 61, in _module_lowering
    run_pipeline_with_repro_report(
  File "./workspace/torch-mlir/build/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 127, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch._dynamo.exc.BackendCompilerFailed: backend='custom_backend' raised:
TorchMlirCompilerError: Lowering TorchFX IR -> Torch Backend IR failed with the following diagnostics:


python exception: Failure while executing pass pipeline

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{ extra-library=})' /tmp/UnnammedModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.


Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Does anybody know the root cause?
It seems the torch-mlir fails to import in-place op (aten.copy_).

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions