Skip to content

Commit b748d89

Browse files
committed
Equalize rank for tosa min and max
1 parent 26be67e commit b748d89

File tree

3 files changed

+37
-26
lines changed

3 files changed

+37
-26
lines changed

src/Conversion/ONNXToTOSA/Math/Elementwise.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,16 @@ class ONNXClipOpLoweringToTOSA : public OpConversionPattern<ONNXClipOp> {
347347
rewriter.getF32FloatAttr(maxFloat));
348348
} else {
349349
if (!isNoneValue(min)) {
350+
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), res, min)
351+
.failed())
352+
return failure();
350353
res = tosa::CreateOpAndInfer<mlir::tosa::MaximumOp>(
351354
rewriter, op->getLoc(), op.getType(), res, min);
352355
}
353356
if (!isNoneValue(max)) {
357+
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), res, max)
358+
.failed())
359+
return failure();
354360
res = tosa::CreateOpAndInfer<mlir::tosa::MinimumOp>(
355361
rewriter, op->getLoc(), op.getType(), res, max);
356362
}

src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
#ifndef ONNXMLIR_CONVERSION_ONNXTOTOSA_TOSALEGALIZEUTILS_H
1818
#define ONNXMLIR_CONVERSION_ONNXTOTOSA_TOSALEGALIZEUTILS_H
1919

20-
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
21-
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
20+
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
21+
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
22+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
2223
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
2324
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
2425
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project

test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -589,12 +589,14 @@ func.func @test_tanh(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
589589
func.func @test_clip(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<3xi32> {
590590
%0 = "onnx.Clip"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
591591
return %0 : tensor<3xi32>
592-
// CHECK-LABEL: func @test_clip
593-
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>, [[PARAM_1_:%.+]]: tensor<i32>, [[PARAM_2_:%.+]]: tensor<i32>) -> tensor<3xi32>
594-
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.maximum [[PARAM_0_]], [[PARAM_1_]] : (tensor<3xi32>, tensor<i32>) -> tensor<3xi32>
595-
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.minimum [[VAR_0_]], [[PARAM_2_]] : (tensor<3xi32>, tensor<i32>) -> tensor<3xi32>
596-
// CHECK-NEXT: return [[VAR_1_]] : tensor<3xi32>
597-
// CHECK-NEXT: }
592+
// CHECK-LABEL: func.func @test_clip
593+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>, [[PARAM_1_:%.+]]: tensor<i32>, [[PARAM_2_:%.+]]: tensor<i32>) -> tensor<3xi32> {
594+
// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
595+
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.maximum [[PARAM_0_]], [[VAR_0_]] : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
596+
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
597+
// CHECK: [[VAR_3_:%.+]] = tosa.minimum [[VAR_1_]], [[VAR_2_]] : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
598+
// CHECK: return [[VAR_3_]] : tensor<3xi32>
599+
// CHECK: }
598600
}
599601

600602
// -----
@@ -604,11 +606,12 @@ func.func @test_clip_default_min_f32(%arg0: tensor<3xf32>, %arg1: tensor<f32>) -
604606
%cst = "onnx.NoValue"() {value} : () -> none
605607
%0 = "onnx.Clip"(%arg0, %cst, %arg1) : (tensor<3xf32>, none, tensor<f32>) -> tensor<3xf32>
606608
return %0 : tensor<3xf32>
607-
// CHECK-LABEL: func @test_clip_default_min_f32
608-
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>, [[PARAM_1_:%.+]]: tensor<f32>) -> tensor<3xf32>
609-
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.minimum [[PARAM_0_]], [[PARAM_1_]] : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
610-
// CHECK-NEXT: return [[VAR_0_]] : tensor<3xf32>
611-
// CHECK-NEXT: }
609+
// CHECK-LABEL: func.func @test_clip_default_min_f32
610+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>, [[PARAM_1_:%.+]]: tensor<f32>) -> tensor<3xf32> {
611+
// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
612+
// CHECK: [[VAR_1_:%.+]] = tosa.minimum [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
613+
// CHECK: return [[VAR_1_]] : tensor<3xf32>
614+
// CHECK: }
612615
}
613616

614617
// -----
@@ -618,11 +621,12 @@ func.func @test_clip_default_max_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<bf16>
618621
%cst = "onnx.NoValue"() {value} : () -> none
619622
%0 = "onnx.Clip"(%arg0, %arg1, %cst) : (tensor<3xbf16>, tensor<bf16>, none) -> tensor<3xbf16>
620623
return %0 : tensor<3xbf16>
621-
// CHECK-LABEL: func @test_clip_default_max_bf16
622-
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>, [[PARAM_1_:%.+]]: tensor<bf16>) -> tensor<3xbf16>
623-
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.maximum [[PARAM_0_]], [[PARAM_1_]] : (tensor<3xbf16>, tensor<bf16>) -> tensor<3xbf16>
624-
// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16>
625-
// CHECK-NEXT: }
624+
// CHECK-LABEL: func.func @test_clip_default_max_bf16
625+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>, [[PARAM_1_:%.+]]: tensor<bf16>) -> tensor<3xbf16> {
626+
// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1>} : (tensor<bf16>) -> tensor<1xbf16>
627+
// CHECK: [[VAR_1_:%.+]] = tosa.maximum [[PARAM_0_]], [[VAR_0_]] : (tensor<3xbf16>, tensor<1xbf16>) -> tensor<3xbf16>
628+
// CHECK: return [[VAR_1_]] : tensor<3xbf16>
629+
// CHECK: }
626630
}
627631

628632
// -----
@@ -648,14 +652,14 @@ func.func @test_clip_constant_minimum_maximum_non_splat(%arg0: tensor<3xi32>) ->
648652
%cst2 = "onnx.Constant"() {value = dense<[2]> : tensor<1xi32>} : () -> tensor<1xi32>
649653
%0 = "onnx.Clip"(%arg0, %cst1, %cst2) : (tensor<3xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
650654
return %0 : tensor<3xi32>
651-
// CHECK-LABEL: func @test_clip_constant_minimum_maximum_non_splat
652-
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xi32>
653-
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[-1, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
654-
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
655-
// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.maximum [[PARAM_0_]], [[VAR_0_]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
656-
// CHECK-NEXT: [[VAR_3_:%.+]] = tosa.minimum [[VAR_2_]], [[VAR_1_]] : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
657-
// CHECK-NEXT: return [[VAR_3_]] : tensor<3xi32>
658-
// CHECK-NEXT: }
655+
// CHECK-LABEL: func.func @test_clip_constant_minimum_maximum_non_splat
656+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xi32> {
657+
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[-1, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
658+
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
659+
// CHECK: [[VAR_2_:%.+]] = tosa.maximum [[PARAM_0_]], [[VAR_0_]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
660+
// CHECK: [[VAR_3_:%.+]] = tosa.minimum [[VAR_2_]], [[VAR_1_]] : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
661+
// CHECK: return [[VAR_3_]] : tensor<3xi32>
662+
// CHECK: }
659663
}
660664

661665
func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {

0 commit comments

Comments
 (0)