Skip to content

Commit d40ef12

Browse files
authored
Merge pull request #368 from Xilinx/jrickert.allow_lowering_of_int4_to_tosa
Allow lowering of (u)int4 to tosa.
2 parents 315aa0f + 1a18271 commit d40ef12

File tree

4 files changed

+56
-4
lines changed

4 files changed

+56
-4
lines changed

src/Conversion/ONNXToTOSA/Math/Elementwise.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,11 @@ class ONNXCastOpLoweringToTOSA : public OpConversionPattern<ONNXCastOp> {
378378
if (!inputTy) {
379379
return rewriter.notifyMatchFailure(op, "expected valid input type");
380380
}
381+
if (isa<FloatType>(inputTy.getElementType()) &&
382+
resultTy.getElementType().isUnsignedInteger()) {
383+
return rewriter.notifyMatchFailure(
384+
op, "TOSA does not support cast from float to unsigned integer");
385+
}
381386
if (isa<FloatType>(inputTy.getElementType()) &&
382387
isa<IntegerType>(resultTy.getElementType())) {
383388
// ONNX.Cast has truncating behavior, and tosa.cast has rounds

src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ inline bool isTOSABool(mlir::Type type) {
9393

9494
inline bool isTOSAInt(mlir::Type type) {
9595
mlir::IntegerType intType = mlir::dyn_cast<mlir::IntegerType>(type);
96-
std::set<unsigned> intWidth{1, 8, 16, 32, 48, 64};
96+
// Int 4 is not a tosa int, but supported by tosa.mlir
97+
std::set<unsigned> intWidth{1, 4, 8, 16, 32, 48, 64};
9798
return intType && (intType.isSignless() || intType.isUnsignedInteger()) &&
9899
(intWidth.find(intType.getWidth()) != intWidth.end());
99100
}

test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,37 @@ func.func @test_cast_int4_and_uint4_to_from_int8_uint8(%arg0: tensor<1xi4>, %arg
2525
%2 = "onnx.Cast"(%arg1) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8>
2626
%3 = "onnx.Cast"(%2) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4>
2727
onnx.Return %1, %3 : tensor<1xi4>, tensor<1xui4>
28-
// CHECK-LABEL: func.func @test_cast_int4_and_uint4_to_from_int8_uint8(
29-
// TOSA does not support int4 casting
30-
// CHECK-NOT: tosa.cast
28+
// CHECK-LABEL: func.func @test_cast_int4_and_uint4_to_from_int8_uint8
29+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4>, [[PARAM_1_:%.+]]: tensor<1xui4>) -> (tensor<1xi4>, tensor<1xui4>) {
30+
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<1xi4>) -> tensor<1xi8>
31+
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.cast [[VAR_0_]] : (tensor<1xi8>) -> tensor<1xi4>
32+
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.cast [[PARAM_1_]] : (tensor<1xui4>) -> tensor<1xui8>
33+
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.cast [[VAR_2_]] : (tensor<1xui8>) -> tensor<1xui4>
34+
// CHECK: onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4>
35+
// CHECK: }
36+
}
37+
38+
// -----
39+
40+
func.func @test_cast_int4_and_uint4_to_float_and_back(%arg0: tensor<1xi4>, %arg1: tensor<1xui4>) -> (tensor<1xi4>, tensor<1xui4>) {
41+
%0 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = f32} : (tensor<1xi4>) -> tensor<1xf32>
42+
%1 = "onnx.Cast"(%0) {saturate = 1 : si64, to = i4} : (tensor<1xf32>) -> tensor<1xi4>
43+
%2 = "onnx.Cast"(%arg1) {saturate = 1 : si64, to = f32} : (tensor<1xui4>) -> tensor<1xf32>
44+
%3 = "onnx.Cast"(%2) {saturate = 1 : si64, to = ui4} : (tensor<1xf32>) -> tensor<1xui4>
45+
onnx.Return %1, %3 : tensor<1xi4>, tensor<1xui4>
46+
// CHECK-LABEL: func.func @test_cast_int4_and_uint4_to_float_and_back
47+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4>, [[PARAM_1_:%.+]]: tensor<1xui4>) -> (tensor<1xi4>, tensor<1xui4>) {
48+
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<1xi4>) -> tensor<1xf32>
49+
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
50+
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.greater_equal [[VAR_0_]], [[VAR_1_]] : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
51+
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.floor [[VAR_0_]] : (tensor<1xf32>) -> tensor<1xf32>
52+
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.ceil [[VAR_0_]] : (tensor<1xf32>) -> tensor<1xf32>
53+
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.select [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
54+
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.cast [[VAR_5_]] : (tensor<1xf32>) -> tensor<1xi4>
55+
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.cast [[PARAM_1_]] : (tensor<1xui4>) -> tensor<1xf32>
56+
// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Cast"([[VAR_7_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xf32>) -> tensor<1xui4>
57+
// CHECK: onnx.Return [[VAR_6_]], [[VAR_8_]] : tensor<1xi4>, tensor<1xui4>
58+
// CHECK: }
3159
}
3260

3361
// -----

test/mlir/conversion/onnx_to_tosa/Tensor/Constant.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,24 @@ func.func @test_int_dense() -> tensor<2xi8> {
4545

4646
// -----
4747

48+
func.func @test_int4_dense() -> tensor<2xi4> {
49+
%0 = "onnx.Constant"() {value = dense<[-1, -2]> : tensor<2xi4>} : () -> tensor<2xi4>
50+
return %0 : tensor<2xi4>
51+
// CHECK-LABEL: @test_int4_dense() -> tensor<2xi4>
52+
// CHECK: "tosa.const"() <{value = dense<[-1, -2]> : tensor<2xi4>}> : () -> tensor<2xi4>
53+
}
54+
55+
// -----
56+
57+
func.func @test_uint4_dense() -> tensor<2xui4> {
58+
%0 = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xui4>} : () -> tensor<2xui4>
59+
return %0 : tensor<2xui4>
60+
// CHECK-LABEL: @test_uint4_dense() -> tensor<2xui4>
61+
// CHECK: "tosa.const"() <{value = dense<[1, 2]> : tensor<2xui4>}> : () -> tensor<2xui4>
62+
}
63+
64+
// -----
65+
4866
func.func @test_bool_single() -> tensor<i1> {
4967
%0 = "onnx.Constant"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
5068
return %0 : tensor<i1>

0 commit comments

Comments
 (0)