How to reproduce?
Run python test.py with the following script:
import torch
import torch.nn as nn
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize
import torch_mlir
from torch_mlir import fx as fx
from torch_mlir.compiler_utils import OutputType
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.extras.fx_importer import FxImporter
from typing import List


def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Functionalize and retrace the graph, import it into MLIR, then
    # lower it to the linalg-on-tensors backend contract.
    context = torch_mlir.ir.Context()
    torch_d.register_dialect(context)
    gm = make_fx(functionalize(gm), decomposition_table=get_decomposition_table())(*example_inputs)
    fx_importer = FxImporter(context=context, hooks=None)
    fx_importer.import_stateless_graph(gm.graph, func_name="forwards")
    linalg_mlir = fx._module_lowering(
        True,
        False,
        OutputType.LINALG_ON_TENSORS,
        fx_importer.module,
        backend_legal_ops=None,
    )
    return gm.forward


class NormalizeRouting(nn.Module):
    def forward(self, routing_weights: torch.Tensor) -> torch.Tensor:
        # In-place division on the function argument; note that nothing
        # is returned, so the mutation is the only observable effect.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        return


x = torch.randn(2, 4)
compiled_fn = torch.compile(NormalizeRouting(), backend=custom_backend)
compiled_fn(x)
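For context, the only tensor mutation in the script is the in-place /= inside forward. An out-of-place variant like the one below (my own sketch; the class name is made up) presumably avoids the trailing aten.copy_ that shows up in the IR dump further down, since nothing is written back into the function argument:

class NormalizeRoutingOutOfPlace(nn.Module):
    def forward(self, routing_weights: torch.Tensor) -> torch.Tensor:
        # Out-of-place division: produces a fresh tensor instead of
        # mutating the input, so functionalization has no input
        # mutation to replay via aten.copy_.
        return routing_weights / routing_weights.sum(dim=-1, keepdim=True)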
With the original in-place module, the output is:
error: 'torch.aten.copy_' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got '!torch.vtensor<[2,4],f32>'
====================
TorchFX IR
"builtin.module"() ({
"func.func"() <{function_type = (!torch.vtensor<[2,4],f32>) -> (), sym_name = "forwards"}> ({
^bb0(%arg0: !torch.vtensor<[2,4],f32>):
%0 = "torch.constant.int"() <{value = -1 : i64}> : () -> !torch.int
%1 = "torch.prim.ListConstruct"(%0) : (!torch.int) -> !torch.list<int>
%2 = "torch.constant.bool"() <{value = true}> : () -> !torch.bool
%3 = "torch.constant.none"() : () -> !torch.none
%4 = "torch.aten.sum.dim_IntList"(%arg0, %1, %2, %3) : (!torch.vtensor<[2,4],f32>, !torch.list<int>, !torch.bool, !torch.none) -> !torch.vtensor<[2,1],f32>
%5 = "torch.aten.div.Tensor"(%arg0, %4) : (!torch.vtensor<[2,4],f32>, !torch.vtensor<[2,1],f32>) -> !torch.vtensor<[2,4],f32>
%6 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
%7 = "torch.aten.copy_"(%arg0, %5, %6) : (!torch.vtensor<[2,4],f32>, !torch.vtensor<[2,4],f32>, !torch.bool) -> !torch.vtensor<[2,4],f32>
"func.return"() : () -> ()
}) : () -> ()
}) : () -> ()
Traceback (most recent call last):
  File "./workspace/demos/test.py", line 41, in <module>
    compiled_fn(x)
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
    return super().__call__(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
    raise BackendCompilerFailed(
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/torch/__init__.py", line 2425, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./workspace/demos/test.py", line 22, in custom_backend
    linalg_mlir = fx._module_lowering(
    ^^^^^^^^^^^^^^^^^^^^
  File "./workspace/torch-mlir/build/python_packages/torch_mlir/torch_mlir/fx.py", line 61, in _module_lowering
    run_pipeline_with_repro_report(
  File "./workspace/torch-mlir/build/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 127, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch._dynamo.exc.BackendCompilerFailed: backend='custom_backend' raised:
TorchMlirCompilerError: Lowering TorchFX IR -> Torch Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{ extra-library=})' /tmp/UnnammedModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Does anybody know the root cause?
It seems that torch-mlir fails to lower the in-place op (aten.copy_) that functionalization leaves in the graph.
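If it helps: my understanding (an assumption from reading the IR above, not from the torch-mlir sources) is that functionalize turns the /= into an out-of-place aten.div and then appends an aten.copy_ that writes the result back into the graph input, so the input mutation is preserved. Since import_stateless_graph models every argument as an immutable value tensor (!torch.vtensor), that writeback has no legal target, which matches the diagnostic. One possible workaround is to strip that epilogue from the traced graph before importing; strip_input_mutations below is a hypothetical helper of mine, not a torch-mlir API:

def strip_input_mutations(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hypothetical cleanup (assumption, not a torch-mlir API): drop
    # aten.copy_ nodes that write a computed value back into a graph
    # input. These are the input-mutation epilogue inserted by
    # functionalization and cannot be expressed on !torch.vtensor.
    placeholders = {n for n in gm.graph.nodes if n.op == "placeholder"}
    for node in list(gm.graph.nodes):
        if (
            node.op == "call_function"
            and node.target is torch.ops.aten.copy_.default
            and node.args[0] in placeholders
        ):
            node.replace_all_uses_with(node.args[1])
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm

Calling this between make_fx and import_stateless_graph should remove the torch.aten.copy_ on %arg0, at the cost of silently dropping the mutation of the caller's tensor.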