
topk export to stablehlo error #4337

@Wheest


With the following topk model:

import torch
import torch.nn as nn
from torch_mlir import fx

class SimpleTopKModel(nn.Module):
    def __init__(self, n_groups=2, group_dim=4, topk_groups=1):
        super().__init__()
        self.n_groups = n_groups
        self.group_dim = group_dim
        self.topk_groups = topk_groups

    def forward(self, x):
        # x: [batch, n_groups, group_dim]
        group_scores = x.amax(dim=-1)              # [batch, n_groups]
        indices = group_scores.topk(self.topk_groups, dim=-1)[1]  # [batch, topk_groups]; [1] keeps only the indices
        return indices


# example usage
batch, n_groups, group_dim = 2, 3, 4
x = torch.randn(batch, n_groups, group_dim)

model = SimpleTopKModel(n_groups=n_groups, group_dim=group_dim, topk_groups=1)
indices = model(x)

print("indices shape:", indices.shape)
print("indices:", indices)

# stablehlo export
module = fx.export_and_import(
    model,
    x,
    output_type="stablehlo",
)
print(module)

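The StableHLO export crashes with an unclear error: the diagnostics only report a pass-pipeline failure and do not say which op failed to lower.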

Full trace

---------------------------------------------------------------------------

TorchMlirCompilerError                    Traceback (most recent call last)

/tmp/ipython-input-2311695590.py in <cell line: 0>()
     28 
     29 # stablehlo export
---> 30 module = fx.export_and_import(
     31     model,
     32     x,


/usr/local/lib/python3.12/dist-packages/torch_mlir/fx.py in export_and_import(f, output_type, fx_importer, dynamic_shapes, strict, experimental_support_mutation, import_symbolic_shape_expressions, hooks, decomposition_table, func_name, enable_graph_printing, verbose, enable_ir_printing, backend_legal_ops, *args, **kwargs)
    122         )
    123 
--> 124     return _module_lowering(
    125         verbose,
    126         enable_ir_printing,

/usr/local/lib/python3.12/dist-packages/torch_mlir/fx.py in _module_lowering(verbose, enable_ir_printing, output_type, torch_mod, extra_library_file_name, backend_legal_ops)
     65         enable_ir_printing=enable_ir_printing,
     66     )
---> 67     return lower_mlir_module(verbose, output_type, torch_mod)
     68 
     69 

/usr/local/lib/python3.12/dist-packages/torch_mlir/compiler_utils.py in lower_mlir_module(verbose, output_type, module)
    214 
    215     elif output_type == OutputType.STABLEHLO:
--> 216         run_pipeline_with_repro_report(
    217             module,
    218             "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",

/usr/local/lib/python3.12/dist-packages/torch_mlir/compiler_utils.py in run_pipeline_with_repro_report(module, pipeline, description, enable_ir_printing)
    125             """
    126         trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")])
--> 127         raise TorchMlirCompilerError(trimmed_message) from e
    128     finally:
    129         sys.stderr = original_stderr

TorchMlirCompilerError: Lowering Torch Backend IR -> StableHLO Backend IR failed with the following diagnostics:


python exception: Failure while executing pass pipeline

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-stablehlo-backend-pipeline)' /tmp/UnnammedModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
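For anyone reproducing this outside the Python API, here is a small sketch that writes the Torch backend IR (for example, the module shown below) to a file and runs the torch-mlir-opt command from the diagnostics, optionally with the suggested IR-dump flags. It assumes torch-mlir-opt is on PATH; the helper names build_repro_cmd and run_repro and the file name topk_repro.mlir are hypothetical.

```python
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import List

# Pipeline string copied from the diagnostics above.
STABLEHLO_PIPELINE = "builtin.module(torch-backend-to-stablehlo-backend-pipeline)"


def build_repro_cmd(mlir_path: str, dump_ir: bool = True) -> List[str]:
    """Build the torch-mlir-opt command suggested by the error message (hypothetical helper)."""
    # List form needs no shell quoting around the pipeline string.
    cmd = ["torch-mlir-opt", f"-pass-pipeline={STABLEHLO_PIPELINE}", mlir_path]
    if dump_ir:
        # Flags the diagnostics suggest for a per-pass IR dump.
        cmd += ["-mlir-print-ir-after-all", "-mlir-disable-threading"]
    return cmd


def run_repro(torch_ir: str, dump_ir: bool = True) -> subprocess.CompletedProcess:
    """Write Torch backend IR to a temp file and run the StableHLO pipeline on it."""
    if shutil.which("torch-mlir-opt") is None:
        raise FileNotFoundError("torch-mlir-opt is not on PATH")
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "topk_repro.mlir"  # hypothetical file name
        path.write_text(torch_ir)
        return subprocess.run(
            build_repro_cmd(str(path), dump_ir), capture_output=True, text=True
        )
```

Calling run_repro with the IR text and printing result.stderr should show the per-pass dump and, presumably, the last IR before the failing pass, which may point at the op that fails to lower.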

Exporting to the Torch backend IR succeeds and produces:

module {
  func.func @main(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,1],si64> {
    %int-1 = torch.constant.int -1
    %0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
    %false = torch.constant.bool false
    %1 = torch.aten.amax %arg0, %0, %false : !torch.vtensor<[2,3,4],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,3],f32>
    %int1 = torch.constant.int 1
    %int-1_0 = torch.constant.int -1
    %true = torch.constant.bool true
    %true_1 = torch.constant.bool true
    %values, %indices = torch.aten.topk %1, %int1, %int-1_0, %true, %true_1 : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,1],f32>, !torch.vtensor<[2,1],si64>
    return %indices : !torch.vtensor<[2,1],si64>
  }
}
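For checking a lowered module against the model, here is a dependency-free sketch of what forward computes: amax over group_dim, then the indices of the topk_groups largest group scores (largest=true, sorted=true, matching the torch.aten.topk operands above). The function name simple_topk_reference and the sample input are hypothetical, and the sketch breaks ties toward the lower index, which torch.topk does not guarantee.

```python
from typing import List


def simple_topk_reference(x: List[List[List[float]]], topk_groups: int = 1) -> List[List[int]]:
    """Pure-Python sketch of SimpleTopKModel.forward; x is [batch][n_groups][group_dim]."""
    result = []
    for batch_row in x:
        # amax over the last dim: one score per group -> [n_groups]
        group_scores = [max(group) for group in batch_row]
        # topk(largest=True, sorted=True): indices ordered by descending score.
        # sorted() is stable, so equal scores keep the lower index first.
        order = sorted(range(len(group_scores)), key=lambda i: -group_scores[i])
        result.append(order[:topk_groups])
    return result


# batch=2, n_groups=3, group_dim=4, same shape as the repro input
x_example = [
    [[0.1, 0.9, 0.2, 0.3], [0.5, 0.4, 0.0, 0.1], [1.5, -1.0, 0.0, 0.2]],
    [[-0.3, -0.2, -0.1, -0.4], [2.0, 0.0, 0.0, 0.0], [0.7, 0.8, 0.6, 0.5]],
]
print(simple_topk_reference(x_example))  # [[2], [1]], i.e. shape [2, 1] like the IR's si64 result
```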

Possible overlap with #4239
