-
Notifications
You must be signed in to change notification settings - Fork 165
Description
What happened?
StableHLO `dynamic_broadcast_in_dim` operations fail to lower from StableHLO to Linalg when the input tensor has dynamic dimensions. This occurs with a MoE model such as https://huggingface.co/Rakuten/RakutenAI-2.0-8x7B-instruct (exported to ONNX, then converted via onnx-mlir to StableHLO), where the tensor has a dynamic first dimension because the number of tokens routed to each expert is determined at runtime.
error: op was not bufferized
note: see current operation: %14 = "stablehlo.dynamic_broadcast_in_dim"(%0, %13) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<?x4096xf16>, tensor<2xi64>) -> tensor<?x4096xf16>
Snippet from stablehlo-opt
`//===-------------------------------------------===//
Legalizing operation : 'shape.shape_of'(0xc2507e0) {
%20 = "shape.shape_of"(%18) : (tensor<?x14336xf16>) -> tensor<2xindex>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'shape.broadcast'(0xc2c9560) {
%21 = "shape.broadcast"(%19, %20) : (tensor<2xindex>, tensor<2xindex>) -> tensor<2xindex>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.dynamic_broadcast_in_dim'(0xc2c9f10) {
%22 = "stablehlo.dynamic_broadcast_in_dim"(%12, %21) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16>
-
Fold {
} -> FAILURE : unable to fold -
Pattern : 'stablehlo.dynamic_broadcast_in_dim -> ()' {
Trying to match "mlir::stablehlo::{anonymous}::DynamicBroadcastInDimOpToBroadcastConverter"
"mlir::stablehlo::{anonymous}::DynamicBroadcastInDimOpToBroadcastConverter" result 0
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.dynamic_broadcast_in_dim'(0xc2ca010) {
%23 = "stablehlo.dynamic_broadcast_in_dim"(%18, %21) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16>
-
Fold {
} -> FAILURE : unable to fold -
Pattern : 'stablehlo.dynamic_broadcast_in_dim -> ()' {
Trying to match "mlir::stablehlo::{anonymous}::DynamicBroadcastInDimOpToBroadcastConverter"
"mlir::stablehlo::{anonymous}::DynamicBroadcastInDimOpToBroadcastConverter" result 0
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//`
--
Coming from stablehlo.mlir (%18 to %21):
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32", llvm.target_triple = "aarch64-unknown-linux-gnu", "onnx-mlir.symbol-postfix" = "model"} { func.func @forward(%arg0: tensor<14336x4096xf16> {ccompiler.name = "input_0"}, %arg1: tensor<?x4096xf16> {ccompiler.name = "input_1"}, %arg2: tensor<14336x4096xf16> {ccompiler.name = "input_2"}, %arg3: tensor<4096x14336xf16> {ccompiler.name = "input_3"}) -> (tensor<?x4096xf16> {ccompiler.name = "12"}) attributes {llvm.emit_c_interface} { %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<14336x4096xf16>) -> tensor<4096x14336xf16> %1 = shape.shape_of %arg1 : tensor<?x4096xf16> -> tensor<2xindex> %2 = arith.index_cast %1 : tensor<2xindex> to tensor<2xi64> %3 = stablehlo.concatenate %2, dim = 0 : (tensor<2xi64>) -> tensor<2xi64> %4 = stablehlo.dynamic_broadcast_in_dim %arg1, %3, dims = [0, 1] : (tensor<?x4096xf16>, tensor<2xi64>) -> tensor<?x4096xf16> %5 = stablehlo.broadcast_in_dim %0, dims = [0, 1] : (tensor<4096x14336xf16>) -> tensor<4096x14336xf16> %6 = stablehlo.dot %4, %5 : (tensor<?x4096xf16>, tensor<4096x14336xf16>) -> tensor<?x14336xf16> %7 = stablehlo.logistic %6 : tensor<?x14336xf16> %8 = shape.shape_of %6 : tensor<?x14336xf16> -> tensor<2xindex> %9 = shape.shape_of %7 : tensor<?x14336xf16> -> tensor<2xindex> %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> %11 = stablehlo.dynamic_broadcast_in_dim %6, %10, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16> %12 = stablehlo.dynamic_broadcast_in_dim %7, %10, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16> %13 = stablehlo.multiply %11, %12 : tensor<?x14336xf16> %14 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<14336x4096xf16>) -> tensor<4096x14336xf16> %15 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<4096x14336xf16>) -> tensor<4096x14336xf16> %16 = stablehlo.dot %4, %15 : 
(tensor<?x4096xf16>, tensor<4096x14336xf16>) -> tensor<?x14336xf16> %17 = shape.shape_of %13 : tensor<?x14336xf16> -> tensor<2xindex> %18 = shape.shape_of %16 : tensor<?x14336xf16> -> tensor<2xindex> %19 = shape.broadcast %17, %18 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> %20 = stablehlo.dynamic_broadcast_in_dim %13, %19, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16> %21 = stablehlo.dynamic_broadcast_in_dim %16, %19, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16> %22 = stablehlo.multiply %20, %21 : tensor<?x14336xf16> %23 = stablehlo.transpose %arg3, dims = [1, 0] : (tensor<4096x14336xf16>) -> tensor<14336x4096xf16> %24 = shape.shape_of %22 : tensor<?x14336xf16> -> tensor<2xindex> %25 = arith.index_cast %24 : tensor<2xindex> to tensor<2xi64> %26 = stablehlo.concatenate %25, dim = 0 : (tensor<2xi64>) -> tensor<2xi64> %27 = stablehlo.dynamic_broadcast_in_dim %22, %26, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xi64>) -> tensor<?x14336xf16> %28 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<14336x4096xf16>) -> tensor<14336x4096xf16> %29 = stablehlo.dot %27, %28 : (tensor<?x14336xf16>, tensor<14336x4096xf16>) -> tensor<?x4096xf16> return %29 : tensor<?x4096xf16> } }
Steps to reproduce your issue
- Go to '...'
- Click on '....'
- Scroll down to '....'
- See error
Version information
No response