
Commit 7a7beda

added support for Celu op (#3139)
Signed-off-by: logeshwaranmcw <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
1 parent 66e82a9 commit 7a7beda

File tree (5 files changed: +153 -11 lines changed)

docs/SupportedONNXOps-cpu.md
src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
test/backend/inference_backend.py
test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir
test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir


docs/SupportedONNXOps-cpu.md

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 22. Limitatio
 | **CastMap** |none | | | |
 | **CategoryMapper** |none | | | |
 | **Ceil** |6 - * | | |
-| **Celu** |none | | | |
+| **Celu** |12 - * | | | |
 | **CenterCropPad** |none | | | |
 | **Clip** |6 - * |No support for short integers. | |
 | **Col2Im** |none | | | |

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 61 additions & 10 deletions
@@ -752,6 +752,56 @@ Value emitScalarOpFor<ONNXReluOp>(ConversionPatternRewriter &rewriter,
   return create.math.max(zero, operand);
 }
 
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXCeLUOp
+//===----------------------------------------------------------------------===//
+
+template <>
+struct ScalarOp<ONNXCeluOp> {
+  using FOp = CustomScalarOp;
+  using IOp = CustomScalarOp;
+};
+
+template <>
+GenOpMix getGenOpMix<ONNXCeluOp>(Type t, Operation *op) {
+  return {{GenericOps::ArithmeticGop, 2}, {GenericOps::MulGop, 1},
+      {GenericOps::MinMaxGop, 2}, {GenericOps::ExpGop, 1},
+      {GenericOps::DivGop, 1}};
+}
+
+template <>
+// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
+Value emitScalarOpFor<ONNXCeluOp>(ConversionPatternRewriter &rewriter,
+    Location loc, Operation *op, Type elementType,
+    ArrayRef<Value> scalarOperands) {
+  CheckIfCustomScalarOpIsSupported<ONNXCeluOp>(elementType);
+  Value operand = scalarOperands[0];
+  MultiDialectBuilder<MathBuilder> create(rewriter, loc);
+
+  // Get the 'alpha' attribute from the Celu operation.
+  auto celuOp = cast<ONNXCeluOp>(op);
+
+  double alphaValue = celuOp.getAlpha().convertToDouble();
+
+  // Create constants for 0, 1, and alpha.
+  Value zero = create.math.constant(elementType, 0.0);
+  Value one = create.math.constant(elementType, 1.0);
+  Value alpha = create.math.constant(elementType, alphaValue);
+
+  // Compute positive part: max(0, x)
+  Value positivePart = create.math.max(zero, operand);
+
+  // Compute negative part: alpha * (exp(x / alpha) - 1)
+  Value xOverAlpha = create.math.div(operand, alpha);
+  Value expVal = create.math.exp(xOverAlpha);
+  Value expMinusOne = create.math.sub(expVal, one);
+  Value scaled = create.math.mul(alpha, expMinusOne);
+
+  // Combine parts: positivePart + min(0, scaled)
+  Value negativePart = create.math.min(zero, scaled);
+  return create.math.add(positivePart, negativePart);
+}
+
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXLeakyReluOp
 //===----------------------------------------------------------------------===//
@@ -785,7 +835,6 @@ Value emitScalarOpFor<ONNXLeakyReluOp>(ConversionPatternRewriter &rewriter,
   return create.math.select(
       lessThanZero, create.math.mul(alpha, operand), operand);
 }
-
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXPReluOp
 //===----------------------------------------------------------------------===//
@@ -1756,15 +1805,16 @@ bool OpFusionHelper::checkFusibleOp(Operation *useOp, Operation *defOp,
       // Unary Op
      mlir::ONNXAbsOp, mlir::ONNXAtanOp, mlir::ONNXCastOp, mlir::ONNXCeilOp,
      mlir::ONNXCosOp, mlir::ONNXCoshOp, mlir::ONNXDequantizeLinearOp,
-     mlir::ONNXEluOp, mlir::ONNXErfOp, mlir::ONNXAcosOp, mlir::ONNXAcoshOp,
-     mlir::ONNXAsinOp, mlir::ONNXAsinhOp, mlir::ONNXAtanhOp, mlir::ONNXExpOp,
-     mlir::ONNXFloorOp, mlir::ONNXGeluOp, mlir::ONNXHardSigmoidOp,
-     mlir::ONNXHardSwishOp, mlir::ONNXIsInfOp, mlir::ONNXIsNaNOp,
-     mlir::ONNXLeakyReluOp, mlir::ONNXLogOp, mlir::ONNXNegOp, mlir::ONNXNotOp,
-     mlir::ONNXReciprocalOp, mlir::ONNXReluOp, mlir::ONNXRoundOp,
-     mlir::ONNXSeluOp, mlir::ONNXSigmoidOp, mlir::ONNXSignOp, mlir::ONNXSinOp,
-     mlir::ONNXSinhOp, mlir::ONNXSoftplusOp, mlir::ONNXSoftsignOp,
-     mlir::ONNXSqrtOp, mlir::ONNXTanOp, mlir::ONNXTanhOp,
+     mlir::ONNXCeluOp, mlir::ONNXEluOp, mlir::ONNXErfOp, mlir::ONNXAcosOp,
+     mlir::ONNXAcoshOp, mlir::ONNXAsinOp, mlir::ONNXAsinhOp, mlir::ONNXAtanhOp,
+     mlir::ONNXExpOp, mlir::ONNXFloorOp, mlir::ONNXGeluOp,
+     mlir::ONNXHardSigmoidOp, mlir::ONNXHardSwishOp, mlir::ONNXIsInfOp,
+     mlir::ONNXIsNaNOp, mlir::ONNXLeakyReluOp, mlir::ONNXLogOp,
+     mlir::ONNXNegOp, mlir::ONNXNotOp, mlir::ONNXReciprocalOp,
+     mlir::ONNXReluOp, mlir::ONNXRoundOp, mlir::ONNXSeluOp,
+     mlir::ONNXSigmoidOp, mlir::ONNXSignOp, mlir::ONNXSinOp, mlir::ONNXSinhOp,
+     mlir::ONNXSoftplusOp, mlir::ONNXSoftsignOp, mlir::ONNXSqrtOp,
+     mlir::ONNXTanOp, mlir::ONNXTanhOp,
      // Binary Op
      mlir::ONNXEqualOp, mlir::ONNXGreaterOp, mlir::ONNXGreaterOrEqualOp,
      mlir::ONNXLessOp, mlir::ONNXLessOrEqualOp, mlir::ONNXModOp,
@@ -2708,6 +2758,7 @@ void populateLoweringONNXElementwiseOpPattern(RewritePatternSet &patterns,
      ONNXElementwiseBinaryOpLowering<mlir::ONNXBitwiseXorOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXCeilOp>,
+     ONNXElementwiseUnaryOpLowering<mlir::ONNXCeluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXDequantizeLinearOp>,
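
For reference, the scalar sequence emitted by the new emitScalarOpFor<ONNXCeluOp> specialization follows the formula quoted in its comment, celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1)), which matches the ONNX Celu definition. The following is a minimal standalone C++ sketch of that computation, for illustration only; it is not part of the commit, and the helper name celuRef and the sample inputs are invented for the example.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <initializer_list>

// Reference computation mirroring the lowering above:
//   celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
static float celuRef(float x, float alpha = 1.0f) {
  float positivePart = std::max(0.0f, x);
  float negativePart = std::min(0.0f, alpha * (std::exp(x / alpha) - 1.0f));
  return positivePart + negativePart;
}

int main() {
  // With alpha = 1.0 (the value used in the MLIR tests below), the division
  // and multiplication by alpha are identity operations.
  for (float x : {-2.0f, -0.5f, 0.0f, 1.5f})
    std::printf("celu(%+.1f) = %+.6f\n", x, celuRef(x, 1.0f));
  return 0;
}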

test/backend/inference_backend.py

Lines changed: 12 additions & 0 deletions
@@ -484,6 +484,18 @@ def get_test_models():
             DYNAMIC_SHAPE: {-1: {-1}},
             CONSTANT_INPUT: {-1},
         },
+        # ==OP== Celu
+        # ==MIN== 12
+        "test_celu_cpu": {
+            STATIC_SHAPE: {},
+            DYNAMIC_SHAPE: {-1: {-1}},
+            CONSTANT_INPUT: {-1},
+        },
+        "test_celu_expanded_cpu": {
+            STATIC_SHAPE: {},
+            DYNAMIC_SHAPE: {-1: {-1}},
+            CONSTANT_INPUT: {-1},
+        },
         # ==OP== Clip
         # ==MIN== 6
         # ==LIM== No support for short integers
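
Note: test_celu_cpu is the standard ONNX node test for Celu, and test_celu_expanded_cpu is its "expanded" variant, which runs the op through its ONNX function-body decomposition rather than the primitive. Following the pattern of the surrounding entries, each test is registered for the static-shape, dynamic-shape (all dimensions dynamic), and constant-input configurations.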

test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir

Lines changed: 35 additions & 0 deletions
@@ -1184,6 +1184,41 @@ func.func private @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 
 // -----
 
+func.func private @test_celu(%arg0 : tensor<?x3x224x224xf32>) -> tensor<?x3x224x224xf32> {
+  %0 = "onnx.Celu"(%arg0) {alpha = 1.000000e+00 : f32} : (tensor<?x3x224x224xf32>) -> tensor<?x3x224x224xf32>
+  func.return %0 : tensor<?x3x224x224xf32>
+
+// mlir2FileCheck.py
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func private @test_celu
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x3x224x224xf32>) -> memref<?x3x224x224xf32> {
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: [[CST_IDX_0_:%.+]] = arith.constant 0 : index
+// CHECK: [[DIM_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_IDX_0_]] : memref<?x3x224x224xf32>
+// CHECK-DAG: [[ALLOC_:%.+]] = memref.alloc([[DIM_0_]]) {{.*}}: memref<?x3x224x224xf32>
+// CHECK-DAG: [[LOOPS_:%.+]]:4 = krnl.define_loops 4
+// CHECK-DAG: [[VAR_DIM_:%.+]] = memref.dim [[PARAM_0_]], [[CST_IDX_0_]] : memref<?x3x224x224xf32>
+// CHECK: krnl.iterate([[LOOPS_]]#0, [[LOOPS_]]#1, [[LOOPS_]]#2, [[LOOPS_]]#3) with (
+// CHECK-SAME: [[LOOPS_]]#0 -> [[I0_:%.+]] = 0 to [[MAP_0_]]([[VAR_DIM_]]),
+// CHECK-SAME: [[LOOPS_]]#1 -> [[I1_:%.+]] = 0 to 3,
+// CHECK-SAME: [[LOOPS_]]#2 -> [[I2_:%.+]] = 0 to 224,
+// CHECK-SAME: [[LOOPS_]]#3 -> [[I3_:%.+]] = 0 to 224){
+// CHECK: [[IVS_:%.+]]:4 = krnl.get_induction_var_value([[LOOPS_]]#0, [[LOOPS_]]#1, [[LOOPS_]]#2, [[LOOPS_]]#3)
+// CHECK: [[LOAD_:%.+]] = krnl.load [[PARAM_0_]]{{.*}}[[IVS_]]#0, [[IVS_]]#1, [[IVS_]]#2, [[IVS_]]#3] : memref<?x3x224x224xf32>
+// CHECK: [[MAX_:%.+]] = arith.maxnumf [[LOAD_]], [[CST_0_]] : f32
+// CHECK: [[EXP_:%.+]] = math.exp [[LOAD_]] : f32
+// CHECK: [[SUB_:%.+]] = arith.subf [[EXP_]], [[CST_1_]] : f32
+// CHECK: [[MIN_:%.+]] = arith.minnumf [[SUB_]], [[CST_0_]] : f32
+// CHECK: [[SUM_:%.+]] = arith.addf [[MAX_]], [[MIN_]] : f32
+// CHECK: krnl.store [[SUM_]], [[ALLOC_]]{{.*}}[[IVS_]]#0, [[IVS_]]#1, [[IVS_]]#2, [[IVS_]]#3] : memref<?x3x224x224xf32>
+// CHECK: }
+// CHECK: return [[ALLOC_]] : memref<?x3x224x224xf32>
+// CHECK: }
+}
+
+// -----
+
 func.func private @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
   "func.return"(%0) : (tensor<*xf32>) -> ()

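In the checks above, the lowered loop body contains only arith.maxnumf, math.exp, arith.subf, arith.minnumf, and arith.addf. Because the test uses the default alpha = 1.0, the x/alpha division and the multiplication by alpha reduce to identity operations and are folded away, leaving max(0, x) + min(0, exp(x) - 1).
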
test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir

Lines changed: 44 additions & 0 deletions
@@ -1673,6 +1673,50 @@ func.func private @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 
 // -----
 
+func.func private @test_celu(%arg0 : tensor<?x3x224x224xf32>) -> tensor<?x3x224x224xf32> {
+  %0 = "onnx.Celu"(%arg0) {alpha = 1.000000e+00 : f32} : (tensor<?x3x224x224xf32>) -> tensor<?x3x224x224xf32>
+  func.return %0 : tensor<?x3x224x224xf32>
+
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 150528)>
+// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)>
+// CHECK-LABEL: func.func private @test_celu
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x3x224x224xf32>) -> memref<?x3x224x224xf32> {
+// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32>
+// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x3x224x224xf32>
+
+// CHECK: [[RES_ALLOC_:%.+]] = memref.alloc([[VAR_dim_]]) {alignment = 16 : i64} : memref<?x3x224x224xf32>
+// CHECK-DAG: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x3x224x224xf32>
+// CHECK-DAG: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_1_]]
+// CHECK-DAG: [[RESHAPE_ALLOC_:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+// CHECK: affine.store [[VAR_0_]], [[RESHAPE_ALLOC_]][0] : memref<1xindex>
+// CHECK-DAG: [[VAR_RESHAPE_:%.+]] = memref.reshape [[PARAM_0_]]([[RESHAPE_ALLOC_]]) : (memref<?x3x224x224xf32>, memref<1xindex>) -> memref<?xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]
+// CHECK-DAG: [[RESHAPE_ALLOC_2_:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+// CHECK: affine.store [[VAR_1_]], [[RESHAPE_ALLOC_2_]][0] : memref<1xindex>
+// CHECK: [[VAR_RESHAPE_4_:%.+]] = memref.reshape [[RES_ALLOC_]]([[RESHAPE_ALLOC_2_]]) : (memref<?x3x224x224xf32>, memref<1xindex>) -> memref<?xf32>
+// CHECK: krnl.iterate() with (){
+// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
+// CHECK: [[BLOCK_TILE_0_:%.+]], [[BLOCK_IN_0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK: krnl.iterate(%loop_block) with (%2 -> %arg1 = 0 to #map1()[%dim_1, %dim, %1]){
+// CHECK: [[IV_0_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE_0_]]) : (!krnl.loop) -> index
+// CHECK: [[VLOAD_:%.+]] = vector.load [[VAR_RESHAPE_:%.+]]{{\[}}[[IV_0_]]] : memref<?xf32>, vector<32xf32>
+// CHECK: [[VMAX_:%.+]] = arith.maxnumf [[VLOAD_]], [[VAR_cst_0_]] : vector<32xf32>
+// CHECK: [[VEXP_:%.+]] = math.exp [[VLOAD_]] : vector<32xf32>
+// CHECK: [[VSUB_:%.+]] = arith.subf [[VEXP_]], [[VAR_cst_]] : vector<32xf32>
+// CHECK: [[VMIN_:%.+]] = arith.minnumf [[VSUB_]], [[VAR_cst_0_]] : vector<32xf32>
+// CHECK: [[VADD_:%.+]] = arith.addf [[VMAX_]], [[VMIN_]] : vector<32xf32>
+// CHECK: vector.store [[VADD_]], [[VAR_RESHAPE_4_:%.+]]{{\[}}[[IV_0_]]] : memref<?xf32>, vector<32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: return [[RES_ALLOC_]] : memref<?x3x224x224xf32>
+
+}
+
+
+
+// -----
 
 func.func private @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
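
This O3 variant checks the vectorized path of the same lowering: the ?x3x224x224 input is flattened with memref.reshape into a 1-D memref of s0 * 150528 elements (150528 = 3 * 224 * 224), the loop is blocked by 32, and the same max/exp/sub/min/add sequence is applied to vector<32xf32> values.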
