-
Notifications
You must be signed in to change notification settings - Fork 606
Open
Description
Consider the following top-k model:
import torch
import torch.nn as nn
from torch_mlir import fx
class SimpleTopKModel(nn.Module):
    """Minimal repro model: select the top-k groups by per-group maximum.

    Args:
        n_groups: Number of groups in the input. Stored on the module only;
            ``forward`` does not read it.
        group_dim: Size of each group. Stored on the module only;
            ``forward`` does not read it.
        topk_groups: The ``k`` passed to ``topk`` in ``forward``.
    """

    def __init__(self, n_groups=2, group_dim=4, topk_groups=1):
        super().__init__()
        self.n_groups = n_groups
        self.group_dim = group_dim
        self.topk_groups = topk_groups

    def forward(self, x):
        """Return the indices of the ``topk_groups`` highest-scoring groups.

        Args:
            x: Tensor of shape ``[batch, n_groups, group_dim]``.

        Returns:
            Integer tensor of shape ``[batch, topk_groups]`` with indices into
            the group axis.
        """
        # Score each group by its largest element: [batch, n_groups].
        group_scores = x.amax(dim=-1)
        # topk returns (values, indices); only the indices are returned.
        indices = group_scores.topk(self.topk_groups, dim=-1)[1]
        return indices
# Example usage: build a small random input and run the model eagerly in
# plain PyTorch first.
batch = 2
n_groups = 3
group_dim = 4
x = torch.randn(batch, n_groups, group_dim)
model = SimpleTopKModel(
    n_groups=n_groups,
    group_dim=group_dim,
    topk_groups=1,
)
indices = model(x)
print("indices shape:", indices.shape)
print("indices:", indices)

# StableHLO export: this is the call that raises TorchMlirCompilerError.
module = fx.export_and_import(model, x, output_type="stablehlo")
print(module)
The StableHLO export crashes with an unhelpful error that does not identify the failing operation.
Full trace
---------------------------------------------------------------------------
TorchMlirCompilerError Traceback (most recent call last)
[/tmp/ipython-input-2311695590.py](https://localhost:8080/#) in <cell line: 0>()
28
29 # stablehlo export
---> 30 module = fx.export_and_import(
31 model,
32 x,
3 frames
[/usr/local/lib/python3.12/dist-packages/torch_mlir/fx.py](https://localhost:8080/#) in export_and_import(f, output_type, fx_importer, dynamic_shapes, strict, experimental_support_mutation, import_symbolic_shape_expressions, hooks, decomposition_table, func_name, enable_graph_printing, verbose, enable_ir_printing, backend_legal_ops, *args, **kwargs)
122 )
123
--> 124 return _module_lowering(
125 verbose,
126 enable_ir_printing,
[/usr/local/lib/python3.12/dist-packages/torch_mlir/fx.py](https://localhost:8080/#) in _module_lowering(verbose, enable_ir_printing, output_type, torch_mod, extra_library_file_name, backend_legal_ops)
65 enable_ir_printing=enable_ir_printing,
66 )
---> 67 return lower_mlir_module(verbose, output_type, torch_mod)
68
69
[/usr/local/lib/python3.12/dist-packages/torch_mlir/compiler_utils.py](https://localhost:8080/#) in lower_mlir_module(verbose, output_type, module)
214
215 elif output_type == OutputType.STABLEHLO:
--> 216 run_pipeline_with_repro_report(
217 module,
218 "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
[/usr/local/lib/python3.12/dist-packages/torch_mlir/compiler_utils.py](https://localhost:8080/#) in run_pipeline_with_repro_report(module, pipeline, description, enable_ir_printing)
125 """
126 trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")])
--> 127 raise TorchMlirCompilerError(trimmed_message) from e
128 finally:
129 sys.stderr = original_stderr
TorchMlirCompilerError: Lowering Torch Backend IR -> StableHLO Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-stablehlo-backend-pipeline)' /tmp/UnnammedModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
Exporting only to the Torch backend IR (rather than StableHLO) succeeds and produces:
// NOTE(review): Torch backend IR for SimpleTopKModel with a [2,3,4] f32 input.
// Lowering to this level reportedly succeeds; the failure occurs afterwards in
// torch-backend-to-stablehlo-backend-pipeline. The diagnostics do not name the
// failing op — torch.aten.topk is the likely suspect, TODO confirm.
module {
func.func @main(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,1],si64> {
// x.amax(dim=-1) with keepdim=false: [2,3,4] -> [2,3] group scores.
%int-1 = torch.constant.int -1
%0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%false = torch.constant.bool false
%1 = torch.aten.amax %arg0, %0, %false : !torch.vtensor<[2,3,4],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,3],f32>
// group_scores.topk(1, dim=-1): k=1 along the last dim. The two bool operands
// are presumably largest=true and sorted=true (the PyTorch defaults) — verify
// against the aten.topk schema.
%int1 = torch.constant.int 1
%int-1_0 = torch.constant.int -1
%true = torch.constant.bool true
%true_1 = torch.constant.bool true
%values, %indices = torch.aten.topk %1, %int1, %int-1_0, %true, %true_1 : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,1],f32>, !torch.vtensor<[2,1],si64>
// Only the si64 indices [2,1] are returned; %values is unused.
return %indices : !torch.vtensor<[2,1],si64>
}
}
This may overlap with #4239.
Metadata
Metadata
Assignees
Labels
No labels