Lowering stablehlo.dynamic_broadcast_in_dim to linalg mlir fails with dynamic shapes #2833

@kgao-tsv

What happened?

`stablehlo.dynamic_broadcast_in_dim` operations fail to lower from StableHLO to Linalg when the input tensor has a dynamic dimension. The IR comes from a MoE model such as https://huggingface.co/Rakuten/RakutenAI-2.0-8x7B-instruct, exported through onnx-mlir to StableHLO. The tensor's first dimension is dynamic because the number of tokens routed to each expert is only known at runtime.

error: op was not bufferized
note: see current operation: %14 = "stablehlo.dynamic_broadcast_in_dim"(%0, %13) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<?x4096xf16>, tensor<2xi64>) -> tensor<?x4096xf16>

Snippet from the stablehlo-opt debug output:
```
//===-------------------------------------------===//
Legalizing operation : 'shape.shape_of'(0xc2507e0) {
%20 = "shape.shape_of"(%18) : (tensor<?x14336xf16>) -> tensor<2xindex>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'shape.broadcast'(0xc2c9560) {
%21 = "shape.broadcast"(%19, %20) : (tensor<2xindex>, tensor<2xindex>) -> tensor<2xindex>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'stablehlo.dynamic_broadcast_in_dim'(0xc2c9f10) {
%22 = "stablehlo.dynamic_broadcast_in_dim"(%12, %21) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16>

  • Fold {
    } -> FAILURE : unable to fold

  • Pattern : 'stablehlo.dynamic_broadcast_in_dim -> ()' {
    Trying to match "mlir::stablehlo::{anonymous}::DynamicBroadcastInDimOpToBroadcastConverter"
    "mlir::stablehlo::{anonymous}::DynamicBroadcastInDimOpToBroadcastConverter" result 0
    } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'stablehlo.dynamic_broadcast_in_dim'(0xc2ca010) {
%23 = "stablehlo.dynamic_broadcast_in_dim"(%18, %21) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16>

  • Fold {
    } -> FAILURE : unable to fold

  • Pattern : 'stablehlo.dynamic_broadcast_in_dim -> ()' {
    Trying to match "mlir::stablehlo::{anonymous}::DynamicBroadcastInDimOpToBroadcastConverter"
    "mlir::stablehlo::{anonymous}::DynamicBroadcastInDimOpToBroadcastConverter" result 0
    } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
```

--

The ops in the snippet correspond to %18 through %21 in the input stablehlo.mlir:
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32", llvm.target_triple = "aarch64-unknown-linux-gnu", "onnx-mlir.symbol-postfix" = "model"} {
  func.func @forward(%arg0: tensor<14336x4096xf16> {ccompiler.name = "input_0"}, %arg1: tensor<?x4096xf16> {ccompiler.name = "input_1"}, %arg2: tensor<14336x4096xf16> {ccompiler.name = "input_2"}, %arg3: tensor<4096x14336xf16> {ccompiler.name = "input_3"}) -> (tensor<?x4096xf16> {ccompiler.name = "12"}) attributes {llvm.emit_c_interface} {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<14336x4096xf16>) -> tensor<4096x14336xf16>
    %1 = shape.shape_of %arg1 : tensor<?x4096xf16> -> tensor<2xindex>
    %2 = arith.index_cast %1 : tensor<2xindex> to tensor<2xi64>
    %3 = stablehlo.concatenate %2, dim = 0 : (tensor<2xi64>) -> tensor<2xi64>
    %4 = stablehlo.dynamic_broadcast_in_dim %arg1, %3, dims = [0, 1] : (tensor<?x4096xf16>, tensor<2xi64>) -> tensor<?x4096xf16>
    %5 = stablehlo.broadcast_in_dim %0, dims = [0, 1] : (tensor<4096x14336xf16>) -> tensor<4096x14336xf16>
    %6 = stablehlo.dot %4, %5 : (tensor<?x4096xf16>, tensor<4096x14336xf16>) -> tensor<?x14336xf16>
    %7 = stablehlo.logistic %6 : tensor<?x14336xf16>
    %8 = shape.shape_of %6 : tensor<?x14336xf16> -> tensor<2xindex>
    %9 = shape.shape_of %7 : tensor<?x14336xf16> -> tensor<2xindex>
    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
    %11 = stablehlo.dynamic_broadcast_in_dim %6, %10, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16>
    %12 = stablehlo.dynamic_broadcast_in_dim %7, %10, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16>
    %13 = stablehlo.multiply %11, %12 : tensor<?x14336xf16>
    %14 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<14336x4096xf16>) -> tensor<4096x14336xf16>
    %15 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<4096x14336xf16>) -> tensor<4096x14336xf16>
    %16 = stablehlo.dot %4, %15 : (tensor<?x4096xf16>, tensor<4096x14336xf16>) -> tensor<?x14336xf16>
    %17 = shape.shape_of %13 : tensor<?x14336xf16> -> tensor<2xindex>
    %18 = shape.shape_of %16 : tensor<?x14336xf16> -> tensor<2xindex>
    %19 = shape.broadcast %17, %18 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
    %20 = stablehlo.dynamic_broadcast_in_dim %13, %19, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16>
    %21 = stablehlo.dynamic_broadcast_in_dim %16, %19, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xindex>) -> tensor<?x14336xf16>
    %22 = stablehlo.multiply %20, %21 : tensor<?x14336xf16>
    %23 = stablehlo.transpose %arg3, dims = [1, 0] : (tensor<4096x14336xf16>) -> tensor<14336x4096xf16>
    %24 = shape.shape_of %22 : tensor<?x14336xf16> -> tensor<2xindex>
    %25 = arith.index_cast %24 : tensor<2xindex> to tensor<2xi64>
    %26 = stablehlo.concatenate %25, dim = 0 : (tensor<2xi64>) -> tensor<2xi64>
    %27 = stablehlo.dynamic_broadcast_in_dim %22, %26, dims = [0, 1] : (tensor<?x14336xf16>, tensor<2xi64>) -> tensor<?x14336xf16>
    %28 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<14336x4096xf16>) -> tensor<14336x4096xf16>
    %29 = stablehlo.dot %27, %28 : (tensor<?x14336xf16>, tensor<14336x4096xf16>) -> tensor<?x4096xf16>
    return %29 : tensor<?x4096xf16>
  }
}
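For readers unfamiliar with the op, here is a rough Python model of the shape logic in the module above. It is only a sketch: `shape_broadcast` and `is_noop_broadcast` are hypothetical helper names, not part of StableHLO or MLIR, and the compatibility check is a simplified reading of the `dynamic_broadcast_in_dim` constraints. Under those assumptions, it suggests that every `dynamic_broadcast_in_dim` in this module is an identity at runtime, since each one uses `dims = [0, 1]` and an output shape derived from operands that already share the same extents.

```python
from typing import Sequence, Tuple


def shape_broadcast(lhs: Sequence[int], rhs: Sequence[int]) -> Tuple[int, ...]:
    """Simplified model of shape.broadcast: right-aligned, size-1 dims stretch."""
    rank = max(len(lhs), len(rhs))
    lhs = (1,) * (rank - len(lhs)) + tuple(lhs)
    rhs = (1,) * (rank - len(rhs)) + tuple(rhs)
    out = []
    for a, b in zip(lhs, rhs):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"incompatible extents {a} and {b}")
        out.append(b if a == 1 else a)
    return tuple(out)


def is_noop_broadcast(
    operand_shape: Sequence[int],
    output_shape: Sequence[int],
    broadcast_dimensions: Sequence[int],
) -> bool:
    """Return True when the broadcast leaves the operand unchanged.

    Hypothetical helper: operand dim d maps to output dim
    broadcast_dimensions[d] and must be 1 or equal to that output extent.
    """
    operand_shape = tuple(operand_shape)
    output_shape = tuple(output_shape)
    dims = tuple(broadcast_dimensions)
    if len(dims) != len(operand_shape):
        raise ValueError("one broadcast dimension is needed per operand dim")
    for d, out_d in enumerate(dims):
        if operand_shape[d] not in (1, output_shape[out_d]):
            raise ValueError(f"operand dim {d} cannot broadcast to output dim {out_d}")
    # Identity mapping plus equal shapes means nothing is expanded or permuted.
    return dims == tuple(range(len(output_shape))) and operand_shape == output_shape


# num_tokens stands in for the runtime value of the dynamic '?' dimension.
for num_tokens in (1, 7, 128):
    arg1 = (num_tokens, 4096)    # %arg1
    act = (num_tokens, 14336)    # %6, %7, %13, %16, %22
    # %4: output shape is the operand's own shape.
    assert is_noop_broadcast(arg1, arg1, (0, 1))
    # %11/%12 and %20/%21: output shape is shape.broadcast of two equal shapes.
    out = shape_broadcast(act, act)
    assert is_noop_broadcast(act, out, (0, 1))
```

If that reading holds, the failing ops carry no real broadcasting work, which may be useful when deciding whether to fold them away before the Linalg lowering or to extend the conversion pattern.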

Steps to reproduce your issue

  1. Go to '...'
  2. Click on '....'
  3. Scroll down to '....'
  4. See error

Version information

No response
