-
Notifications
You must be signed in to change notification settings - Fork 165
Open
Description
Request description
// RUN: stablehlo-opt %s --stablehlo-refine-arguments=types="tensor<512xi32>,tensor<16000x512xf32>" --stablehlo-refine-shapes
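//// ERROR: the composite op does not match its decomposition.
//// NOTE(review): types= above lists two types, but @main takes a single argument — confirm this RUN line.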
// RUN: stablehlo-opt %s --stablehlo-refine-arguments=types="tensor<512xi32>" --stablehlo-legalize-composite-to-call --stablehlo-refine-shapes
//// OK: shape refinement handles CallOp correctly.
// CHECK-LABEL: func @main
// CHECK-SAME: (%arg0: tensor<512xi32>) -> tensor<512xi32>
// Entry point. Reads the runtime length (dimension 0) of %arg0 and forwards
// it, together with %arg0 itself, to @main_wrapped. Returns the wrapper's result.
func.func @main(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // Scalar i32 holding the size of dimension 0 of the dynamically shaped input.
  %len = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xi32>) -> tensor<i32>
  %wrapped = func.call @main_wrapped(%len, %arg0) : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
  func.return %wrapped : tensor<?xi32>
}
// CHECK-LABEL: func @main_wrapped
// Wrapper whose first argument carries jax.global_constant = "a". It forwards
// both arguments unchanged to the "my.mul2" composite, whose decomposition is
// @my.mul2.impl, and returns the composite's result.
func.func @main_wrapped(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // Expected after shape refinement: the "my.mul2" composite is kept, still
  // points at its decomposition, and has its result refined to a static shape.
  // NOTE(review): the exact operand list after refinement is left unpinned;
  // confirm whether the global-constant argument is folded away.
  // CHECK: stablehlo.composite "my.mul2" {{.*}} {decomposition = @my.mul2.impl}
  // CHECK-SAME: -> tensor<512xi32>
  %0 = stablehlo.composite "my.mul2" %arg0, %arg1 {decomposition = @my.mul2.impl} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}
// Decomposition of the "my.mul2" composite: multiplies %arg1 elementwise by 2.
// The constant 2 is broadcast to a 1-D tensor whose length is the value of %arg0.
// NOTE(review): presumably %arg0 equals the length of %arg1 (as @main passes it);
// otherwise the multiply's operand shapes would disagree at runtime.
func.func private @my.mul2.impl(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  %two = stablehlo.constant dense<2> : tensor<i32>
  // Turn the scalar length into the 1-element shape operand for the broadcast.
  %out_shape = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
  %twos = stablehlo.dynamic_broadcast_in_dim %two, %out_shape, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<?xi32>
  %product = stablehlo.multiply %arg1, %twos : tensor<?xi32>
  func.return %product : tensor<?xi32>
}
Metadata
Metadata
Assignees
Labels
No labels