
[GlobalOpt] SinkTransposeThroughExtractSlice pattern failed in PropagateLinalgTranspose pass #22687

@ziliangzl


What happened?

Consider the following IR from propagate_extract_slice.mlir:

util.func public @rank_reduced_extract_transposed(%arg0: tensor<80x80x1203xf16>) -> tensor<1203x5x80xf16> {
  %0 = tensor.empty() : tensor<1203x80x80xf16>
  %1 = tensor.empty() : tensor<1203x80x80xf16>
  %2 = tensor.empty() : tensor<1203x80x80xf16>
  %transposed = linalg.transpose ins(%arg0 : tensor<80x80x1203xf16>) outs(%0 : tensor<1203x80x80xf16>) permutation = [2, 0, 1]
  %expanded = tensor.expand_shape %transposed [[0, 1], [2], [3, 4]] output_shape [1, 1203, 80, 80, 1] : tensor<1203x80x80xf16> into tensor<1x1203x80x80x1xf16>
  %extracted_slice = tensor.extract_slice %expanded[0, 0, 0, 0, 0] [1, 1203, 5, 80, 1] [1, 1, 1, 1, 1] : tensor<1x1203x80x80x1xf16> to tensor<1203x5x80xf16>
  util.return %extracted_slice : tensor<1203x5x80xf16>
}
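As a sanity check, the input IR itself type-checks. The small Python model below (an illustrative sketch, not IREE or MLIR code; the helper names are made up) traces 80x80x1203 through the transpose, the expand_shape and the rank-reducing slice to the declared result type 1203x5x80. It assumes, as in this repro, that the slice drops every unit dim.

```python
# Illustrative shape model of the ops in the repro; a sketch, not IREE code.


def transpose_shape(shape, perm):
    """linalg.transpose: result dim i takes the size of input dim perm[i]."""
    assert sorted(perm) == list(range(len(shape))), "perm must be a permutation"
    return [shape[p] for p in perm]


def expand_shape(shape, reassociation, output_shape):
    """tensor.expand_shape: source dim j splits into the result dims in reassociation[j]."""
    assert len(reassociation) == len(shape)
    for src_dim, group in enumerate(reassociation):
        product = 1
        for d in group:
            product *= output_shape[d]
        assert product == shape[src_dim], f"group {group} does not match dim {src_dim}"
    return list(output_shape)


def drop_unit_dims(sizes):
    """Simplification: the repro's rank-reducing slice drops every unit dim."""
    return [s for s in sizes if s != 1]


arg0 = [80, 80, 1203]
transposed = transpose_shape(arg0, [2, 0, 1])
expanded = expand_shape(transposed, [[0, 1], [2], [3, 4]], [1, 1203, 80, 80, 1])
slice_sizes = [1, 1203, 5, 80, 1]
# Offsets are all 0 and strides all 1, so each size only has to fit its dim.
assert all(s <= d for s, d in zip(slice_sizes, expanded))
result = drop_unit_dims(slice_sizes)
print(transposed, expanded, result)  # [1203, 80, 80] [1, 1203, 80, 80, 1] [1203, 5, 80]
```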

Running the following command with -debug enabled:

../iree-build/tools/iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose))" propagate_extract_slice.mlir

fails with the following error:

--- After sinking transpose ops down ---
expected type to be 'tensor<5x80x1x1x1203xf16>' or a rank-reduced version. (size mismatch) 
[mlir-asm-printer AsmPrinter.cpp:2073 1] util.func' failed to verify and will be printed in generic form
"util.func"() <{function_type = (tensor<80x80x1203xf16>) -> tensor<1203x5x80xf16>, sym_name = "rank_reduced_extract_transposed_unit_dim", sym_visibility = "public", tied_operands = [-1 : index]}> ({
^bb0(%arg0: tensor<80x80x1203xf16>):
  %0 = "tensor.expand_shape"(%arg0) <{reassociation = [[0], [1, 2], [3, 4]], static_output_shape = array<i64: 80, 80, 1, 1, 1203>}> : (tensor<80x80x1203xf16>) -> tensor<80x80x1x1x1203xf16>
  %1 = "tensor.extract_slice"(%0) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 5, 80, 1, 1, 1203>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<80x80x1x1x1203xf16>) -> tensor<80x1203x1203xf16>
  %2 = "tensor.empty"() : () -> tensor<1203x80x80xf16>
  %3 = "linalg.transpose"(%1, %2) <{permutation = array<i64: 1, 0, 0>}> ({
  ^bb0(%arg1: f16, %arg2: f16):
    "linalg.yield"(%arg1) : (f16) -> ()
  }) : (tensor<80x1203x1203xf16>, tensor<1203x80x80xf16>) -> tensor<1203x80x80xf16>
  "util.return"(%3) : (tensor<1203x80x80xf16>) -> ()
}) : () -> ()

propagate_extract_slice.mlir:30:22: error: expected type to be 'tensor<5x80x1x1x1203xf16>' or a rank-reduced version. (size mismatch) 
  %extracted_slice = tensor.extract_slice %expanded[0, 0, 0, 0, 0] [1, 1203, 5, 80, 1] [1, 1, 1, 1, 1] : tensor<1x1203x80x80x1xf16> to tensor<1203x5x80xf16>
                     ^
propagate_extract_slice.mlir:30:22: note: see current operation: %1 = "tensor.extract_slice"(%0) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 5, 80, 1, 1, 1203>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<80x80x1x1x1203xf16>) -> tensor<80x1203x1203xf16>
[pass-manager Pass.cpp:683 1] Pipeline failed for pass 'PropagateLinalgTransposePass' on operation 'expected type to be 'tensor<5x80x1x1x1203xf16>' or a rank-reduced version. (size mismatch) 
[mlir-asm-printer AsmPrinter.cpp:2073 1] util.func' failed to verify and will be printed in generic form
"util.func"() <{function_type = (tensor<80x80x1203xf16>) -> tensor<1203x5x80xf16>, sym_name = "rank_reduced_extract_transposed_unit_dim", sym_visibility = "public", tied_operands = [-1 : index]}> ({...}) : () -> ()'
[pass-manager Pass.cpp:683 1] Pipeline failed for pass 'mlir::detail::OpToOpPassAdaptor' on operation 'expected type to be 'tensor<5x80x1x1x1203xf16>' or a rank-reduced version. (size mismatch) 
[mlir-asm-printer AsmPrinter.cpp:2073 1] builtin.module' failed to verify and will be printed in generic form
"builtin.module"() ({...}) : () -> ()'
[pass-manager Pass.cpp:1102 2] PassManager run completed with result: failure
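Working the shapes through by hand suggests where the rewrite goes wrong. In the dumped IR, the transpose appears to have been sunk through tensor.expand_shape correctly (80x80x1x1x1203 with reassociation [[0], [1, 2], [3, 4]]), and the new extract_slice sizes [5, 80, 1, 1, 1203] also look right. The broken parts are the rank-reduced result type tensor<80x1203x1203xf16>, which should be tensor<5x80x1203xf16>, and the linalg.transpose, whose permutation [1, 0, 0] is not a valid permutation and whose init tensor is still 1203x80x80 rather than 1203x5x80. The Python sketch below recomputes these values under the assumption that the slice drops the leading and trailing unit dims (1x1203x5x80x1 to 1203x5x80). The helper names are hypothetical, and this is not how SinkTransposeThroughExtractSlice is actually implemented.

```python
# Hand-worked sketch of sinking the transpose below expand_shape and then below a
# rank-reducing extract_slice. Helper names are hypothetical, not IREE's code.


def sink_transpose_through_expand(src_shape, perm, reassoc, out_shape):
    """Rewrite expand_shape(transpose(x)) as transpose(expand_shape(x))."""
    n = len(src_shape)
    inv = [0] * n
    for i, p in enumerate(perm):
        inv[p] = i  # transposed dim that source dim p ends up in
    new_shape, new_reassoc, start = [], [], {}
    for j in range(n):  # walk source dims in order
        group = reassoc[inv[j]]  # result dims that source dim j expands into
        start[j] = len(new_shape)
        new_reassoc.append(list(range(start[j], start[j] + len(group))))
        new_shape.extend(out_shape[d] for d in group)
    new_perm = []
    for i in range(n):  # each transposed dim keeps its expanded sub-dims together
        new_perm.extend(start[perm[i]] + k for k in range(len(reassoc[i])))
    return new_shape, new_reassoc, new_perm


def sink_transpose_through_slice(perm, sizes, dropped):
    """Rewrite extract_slice(transpose(x)) as transpose(extract_slice(x)).

    sizes are given on the transposed tensor; dropped lists the positions in
    sizes that the rank-reducing slice removes.
    """
    n = len(perm)
    new_sizes = [0] * n
    for i, p in enumerate(perm):
        new_sizes[p] = sizes[i]  # slice the untransposed input instead
    dropped_src = {perm[i] for i in dropped}  # the same dims in input numbering
    kept_src = [j for j in range(n) if j not in dropped_src]
    slice_type = [new_sizes[j] for j in kept_src]
    # Renumber the surviving input dims so the new permutation stays dense.
    renumber = {j: k for k, j in enumerate(kept_src)}
    new_perm = [renumber[perm[i]] for i in range(n) if i not in dropped]
    return new_sizes, slice_type, new_perm


expanded_src, src_reassoc, perm5 = sink_transpose_through_expand(
    [80, 80, 1203], [2, 0, 1], [[0, 1], [2], [3, 4]], [1, 1203, 80, 80, 1])
slice_sizes, slice_type, new_perm = sink_transpose_through_slice(
    perm5, [1, 1203, 5, 80, 1], dropped=[0, 4])
final_type = [slice_type[p] for p in new_perm]
# prints: [80, 80, 1, 1, 1203] [[0], [1, 2], [3, 4]] [3, 4, 0, 1, 2]
print(expanded_src, src_reassoc, perm5)
# prints: [5, 80, 1, 1, 1203] [5, 80, 1203] [2, 0, 1] [1203, 5, 80]
print(slice_sizes, slice_type, new_perm, final_type)
```

Under these assumptions, the sunk form would be an extract_slice producing tensor<5x80x1203xf16> followed by a transpose with permutation [2, 0, 1] into tensor<1203x5x80xf16>, which matches the function's declared return type. That suggests the pattern needs to drop the sliced-away dims from both the result type and the permutation.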


What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug 🐞 (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
