@@ -752,6 +752,56 @@ Value emitScalarOpFor<ONNXReluOp>(ConversionPatternRewriter &rewriter,
752752 return create.math .max (zero, operand);
753753}
754754
755+ // ===----------------------------------------------------------------------===//
756+ // Scalar unary ops for lowering ONNXCeLUOp
757+ // ===----------------------------------------------------------------------===//
758+
759+ template <>
760+ struct ScalarOp <ONNXCeluOp> {
761+ using FOp = CustomScalarOp;
762+ using IOp = CustomScalarOp;
763+ };
764+
765+ template <>
766+ GenOpMix getGenOpMix<ONNXCeluOp>(Type t, Operation *op) {
767+ return {{GenericOps::ArithmeticGop, 2 }, {GenericOps::MulGop, 1 },
768+ {GenericOps::MinMaxGop, 2 }, {GenericOps::ExpGop, 1 },
769+ {GenericOps::DivGop, 1 }};
770+ }
771+
772+ template <>
773+ // celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
774+ Value emitScalarOpFor<ONNXCeluOp>(ConversionPatternRewriter &rewriter,
775+ Location loc, Operation *op, Type elementType,
776+ ArrayRef<Value> scalarOperands) {
777+ CheckIfCustomScalarOpIsSupported<ONNXCeluOp>(elementType);
778+ Value operand = scalarOperands[0 ];
779+ MultiDialectBuilder<MathBuilder> create (rewriter, loc);
780+
781+ // Get the 'alpha' attribute from the Celu operation.
782+ auto celuOp = cast<ONNXCeluOp>(op);
783+
784+ double alphaValue = celuOp.getAlpha ().convertToDouble ();
785+
786+ // Create constants for 0, 1, and alpha.
787+ Value zero = create.math .constant (elementType, 0.0 );
788+ Value one = create.math .constant (elementType, 1.0 );
789+ Value alpha = create.math .constant (elementType, alphaValue);
790+
791+ // Compute positive part: max(0, x)
792+ Value positivePart = create.math .max (zero, operand);
793+
794+ // Compute negative part: alpha * (exp(x / alpha) - 1)
795+ Value xOverAlpha = create.math .div (operand, alpha);
796+ Value expVal = create.math .exp (xOverAlpha);
797+ Value expMinusOne = create.math .sub (expVal, one);
798+ Value scaled = create.math .mul (alpha, expMinusOne);
799+
800+ // Combine parts: positivePart + min(0, scaled)
801+ Value negativePart = create.math .min (zero, scaled);
802+ return create.math .add (positivePart, negativePart);
803+ }
804+
755805// ===----------------------------------------------------------------------===//
756806// Scalar unary ops for lowering ONNXLeakyReluOp
757807// ===----------------------------------------------------------------------===//
@@ -785,7 +835,6 @@ Value emitScalarOpFor<ONNXLeakyReluOp>(ConversionPatternRewriter &rewriter,
785835 return create.math .select (
786836 lessThanZero, create.math .mul (alpha, operand), operand);
787837}
788-
789838// ===----------------------------------------------------------------------===//
790839// Scalar unary ops for lowering ONNXPReluOp
791840// ===----------------------------------------------------------------------===//
@@ -1756,15 +1805,16 @@ bool OpFusionHelper::checkFusibleOp(Operation *useOp, Operation *defOp,
17561805 // Unary Op
17571806 mlir::ONNXAbsOp, mlir::ONNXAtanOp, mlir::ONNXCastOp, mlir::ONNXCeilOp,
17581807 mlir::ONNXCosOp, mlir::ONNXCoshOp, mlir::ONNXDequantizeLinearOp,
1759- mlir::ONNXEluOp, mlir::ONNXErfOp, mlir::ONNXAcosOp, mlir::ONNXAcoshOp,
1760- mlir::ONNXAsinOp, mlir::ONNXAsinhOp, mlir::ONNXAtanhOp, mlir::ONNXExpOp,
1761- mlir::ONNXFloorOp, mlir::ONNXGeluOp, mlir::ONNXHardSigmoidOp,
1762- mlir::ONNXHardSwishOp, mlir::ONNXIsInfOp, mlir::ONNXIsNaNOp,
1763- mlir::ONNXLeakyReluOp, mlir::ONNXLogOp, mlir::ONNXNegOp, mlir::ONNXNotOp,
1764- mlir::ONNXReciprocalOp, mlir::ONNXReluOp, mlir::ONNXRoundOp,
1765- mlir::ONNXSeluOp, mlir::ONNXSigmoidOp, mlir::ONNXSignOp, mlir::ONNXSinOp,
1766- mlir::ONNXSinhOp, mlir::ONNXSoftplusOp, mlir::ONNXSoftsignOp,
1767- mlir::ONNXSqrtOp, mlir::ONNXTanOp, mlir::ONNXTanhOp,
1808+ mlir::ONNXCeluOp, mlir::ONNXEluOp, mlir::ONNXErfOp, mlir::ONNXAcosOp,
1809+ mlir::ONNXAcoshOp, mlir::ONNXAsinOp, mlir::ONNXAsinhOp, mlir::ONNXAtanhOp,
1810+ mlir::ONNXExpOp, mlir::ONNXFloorOp, mlir::ONNXGeluOp,
1811+ mlir::ONNXHardSigmoidOp, mlir::ONNXHardSwishOp, mlir::ONNXIsInfOp,
1812+ mlir::ONNXIsNaNOp, mlir::ONNXLeakyReluOp, mlir::ONNXLogOp,
1813+ mlir::ONNXNegOp, mlir::ONNXNotOp, mlir::ONNXReciprocalOp,
1814+ mlir::ONNXReluOp, mlir::ONNXRoundOp, mlir::ONNXSeluOp,
1815+ mlir::ONNXSigmoidOp, mlir::ONNXSignOp, mlir::ONNXSinOp, mlir::ONNXSinhOp,
1816+ mlir::ONNXSoftplusOp, mlir::ONNXSoftsignOp, mlir::ONNXSqrtOp,
1817+ mlir::ONNXTanOp, mlir::ONNXTanhOp,
17681818 // Binary Op
17691819 mlir::ONNXEqualOp, mlir::ONNXGreaterOp, mlir::ONNXGreaterOrEqualOp,
17701820 mlir::ONNXLessOp, mlir::ONNXLessOrEqualOp, mlir::ONNXModOp,
@@ -2708,6 +2758,7 @@ void populateLoweringONNXElementwiseOpPattern(RewritePatternSet &patterns,
27082758 ONNXElementwiseBinaryOpLowering<mlir::ONNXBitwiseXorOp>,
27092759 ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>,
27102760 ONNXElementwiseUnaryOpLowering<mlir::ONNXCeilOp>,
2761+ ONNXElementwiseUnaryOpLowering<mlir::ONNXCeluOp>,
27112762 ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
27122763 ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
27132764 ONNXElementwiseUnaryOpLowering<mlir::ONNXDequantizeLinearOp>,
0 commit comments