40 changes: 40 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -90,6 +90,45 @@ bool isNotConvProducer(mlir::Value val) {
return true; // If no defining op, assume it's safe
}

// Return true only when the attribute is an IntegerAttr equal to 0, i.e. the
// Gemm's transB flag is explicitly false.
bool isTransBFalse(mlir::Attribute attr) {
  if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
    int64_t val =
        intAttr.getValue().getSExtValue(); // safe for signless integers
    return val == 0;                       // true iff transB is false (0)
  }
  return false; // conservative default for non-integer attributes
}

// Return true when 'val' is defined by an ONNXConstantOp holding a dense
// splat whose floating-point value is zero (used to check mean == 0).
bool isZeroTensorOrSplat(Value val) {
  if (auto constOp = val.getDefiningOp<ONNXConstantOp>()) {
    auto attrOpt = constOp.getValue();
    if (attrOpt.has_value()) {
      if (auto dense = mlir::dyn_cast<DenseElementsAttr>(*attrOpt))
        return dense.isSplat() && dense.getSplatValue<APFloat>().isZero();
    }
  }
  return false;
}

// Return true when 'val' is defined by an ONNXConstantOp holding a dense
// splat whose floating-point value is 1.0 (used to check var == 1).
bool isOneTensorOrSplat(Value val) {
  if (auto constOp = val.getDefiningOp<ONNXConstantOp>()) {
    auto attrOpt = constOp.getValue();
    if (attrOpt.has_value()) {
      if (auto dense = mlir::dyn_cast<DenseElementsAttr>(*attrOpt)) {
        if (dense.isSplat())
          return dense.getSplatValue<APFloat>().convertToDouble() == 1.0;
      }
    }
  }
  return false;
}

// Return true when the scalar attribute is a FloatAttr equal to zero (used to
// check epsilon == 0); non-float attributes conservatively yield false.
bool isZeroAttrOrZeroTensor(Attribute attr) {
  if (auto floatAttr = mlir::dyn_cast<FloatAttr>(attr))
    return floatAttr.getValue().isZero();
  return false;
}

// Get the index of the axis value in the given permutation array.
IntegerAttr getIndexOfAxisInPerm(
PatternRewriter &rewriter, ArrayAttr permAttr, IntegerAttr axis) {
@@ -1985,6 +2024,7 @@ void ONNXBatchNormalizationInferenceModeOp::getCanonicalizationPatterns(
results.insert<FuseBatchNormInferenceModeConvPattern>(context);
results.insert<RewriteBatchNormInferenceModeConvPattern1>(context);
results.insert<RewriteBatchNormInferenceModeConvPattern2>(context);
results.insert<BackwardFoldScaleAxisToGemmPattern>(context);
}

/// on the ONNXAddOp.
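Note: the pattern registered above relies on the algebraic identity
scale * (A*B + C) + bias == A * (B .* scale) + (C .* scale + bias), with scale
and bias broadcast along the output axis. A minimal standalone sketch checking
that identity numerically — plain C++, illustrative names only, independent of
the onnx-mlir code:

// Standalone sketch (not part of this PR): verifies the fold identity for the
// mean = 0, var = 1, epsilon = 0 case handled by the new pattern.
#include <array>
#include <cassert>
#include <cmath>

int main() {
  // A: 1x2, B: 2x3 (transB = 0), C/scale/bias: length-3 per output channel.
  std::array<double, 2> A = {1.0, 2.0};
  double B[2][3] = {{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}};
  std::array<double, 3> C = {1.0, -1.0, 0.5};
  std::array<double, 3> scale = {2.0, 0.5, 3.0};
  std::array<double, 3> bias = {0.1, 0.2, 0.3};

  for (int j = 0; j < 3; ++j) {
    double z = A[0] * B[0][j] + A[1] * B[1][j] + C[j]; // Gemm: Z = A*B + C
    double bn = scale[j] * z + bias[j];                // simplified BatchNorm
    double folded = A[0] * (B[0][j] * scale[j]) +      // Gemm with fused
                    A[1] * (B[1][j] * scale[j]) +      // B' = B .* scale and
                    (C[j] * scale[j] + bias[j]);       // C' = C .* scale + bias
    assert(std::fabs(bn - folded) < 1e-12);
  }
  return 0;
}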
74 changes: 74 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -929,6 +929,80 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<
[(HasRankOf<1> $x)], [], (addBenefit 0)
>;

//===----------------------------------------------------------------------===//
// This optimization folds the composition 'BatchNormalization o Gemm' into a
// single 'Gemm' by fusing the BatchNormalization's scale and bias into new
// 'B' and 'C' parameters for the Gemm operation.
//
// Given:
// (Gemm) Z = A * B + C
// (BatchNormalization in inference mode)
// Y = scale * (Z - mean) / sqrt(var + epsilon) + bias
//
// In inference mode, when mean = 0, var = 1, and epsilon = 0, this
// simplifies to:
//   Y = scale * (Z - 0) / sqrt(1 + 0) + bias = scale * Z + bias
//
// This allows us to recompute:
// Y = A * (scale * B) + (scale * C + bias)
//
// Therefore, we rewrite:
// onnx.BatchNormalizationInferenceMode(
// onnx.Gemm(A, B, C, alpha, beta, transA, transB),
// scale, bias, mean, var
// ) {epsilon = ..., momentum = ...}
//
// as:
// onnx.Gemm(
// A,
// onnx.Mul(B, scale),
// onnx.Add(onnx.Mul(C, scale), bias),
// alpha, beta, transA, transB)
//
// This transformation is only valid when:
// - transB = 0 (to maintain correct shape alignment)
// - mean is 0
// - var is 1
// - epsilon is 0
//
//===----------------------------------------------------------------------===//

def isTransBFalse : Constraint<CPred<
  "onnx_mlir::isTransBFalse($0)">, "transB must be 0"
>;

def meanIsZero : Constraint<
CPred<"onnx_mlir::isZeroTensorOrSplat($0)">, "mean must be 0"
>;

def varIsOne : Constraint<
CPred<"onnx_mlir::isOneTensorOrSplat($0)">, "var must be 1"
>;

def epsIsZero : Constraint<
CPred<"onnx_mlir::isZeroAttrOrZeroTensor($0)">, "epsilon must be 0"
>;

def BackwardFoldScaleAxisToGemmPattern : Pat<
(ONNXBatchNormalizationInferenceModeOp:$res
(ONNXGemmOp $A, $B, $C, $alpha, $beta, $transA, $transB),
$scale, $bias, $_mean, $_var, $_epsilon, $_momentum),
(ONNXGemmOp
$A,
(ONNXMulOp $B, $scale),
(ONNXAddOp
(ONNXMulOp $C, $scale),
$bias),
$alpha, $beta, $transA, $transB),
[
(isTransBFalse $transB),
(meanIsZero $_mean),
(varIsOne $_var),
(epsIsZero $_epsilon)
],
[], (addBenefit 1)
>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXShapeOp
//===----------------------------------------------------------------------===//
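Why the transB = 0 constraint matters: Mul(B, scale) broadcasts the length-N
scale along B's trailing axis. With transB = 0, B is KxN and the trailing axis
is the output axis N, which is what the fold needs; with transB = 1, B is NxK
and the trailing (reduction) axis would be scaled instead. A minimal standalone
check of this — plain C++, illustrative only, not onnx-mlir code:

// Standalone sketch (not part of this PR): with K == N == 2 a naive
// Mul(B, scale) on the transposed B still type-checks, but scales the
// wrong axis, so the folded Gemm diverges from the reference.
#include <cassert>
#include <cmath>

int main() {
  double A[2] = {1.0, 2.0};      // A: 1xK
  double Bt[2][2] = {{0.1, 0.2}, // B stored transposed (transB = 1):
                     {0.3, 0.4}};//   row j holds output channel j
  double scale[2] = {2.0, 0.5};

  // Reference: scale applied per output channel j (what BatchNorm does).
  double want = scale[0] * (A[0] * Bt[0][0] + A[1] * Bt[0][1]); // = 1.0
  // Naive fold: broadcasting multiplies Bt's trailing (reduction) axis.
  double got = A[0] * (Bt[0][0] * scale[0]) +
               A[1] * (Bt[0][1] * scale[1]); // = 0.4
  assert(std::fabs(want - got) > 1e-9); // the naive fold is wrong here
  return 0;
}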
26 changes: 26 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -812,6 +812,32 @@ func.func @test_rewrite_batchnormtestmode_1d_f16(%arg0 : tensor<64xf16>, %scale

// -----

func.func @test_backward_fold_scale_axis(%arg0: tensor<1x256xf32>) -> tensor<1x128xf32> {
%0 = onnx.Constant dense<0.00999999977> : tensor<256x128xf32>
%1 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%2 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%3 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%4 = onnx.Constant dense<0.0> : tensor<128xf32>
%5 = onnx.Constant dense<1.0> : tensor<128xf32>
%6 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "onnx.Gemm_0", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%7 = "onnx.BatchNormalizationInferenceMode"(%6, %2, %3, %4, %5) {epsilon = 0.0 : f32, momentum = 0.899999976 : f32, onnx_node_name = "onnx.BatchNormalizationInferenceMode_1"} : (tensor<1x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%8 = "onnx.Relu"(%7) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %8 : tensor<1x128xf32>
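// The three identical 128-element splat constants (%1, %2, %3) are expected
// to be deduplicated into one constant during canonicalization, which is why
// the CHECK lines below reuse [[VAR_1_]] for C, scale, and bias alike.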
// CHECK-LABEL: func @test_backward_fold_scale_axis
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x256xf32>) -> tensor<1x128xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<256x128xf32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<128xf32>
// CHECK: [[MUL_0_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_1_]])
// CHECK: [[MUL_1_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]])
// CHECK: [[ADD_0_:%.+]] = "onnx.Add"([[MUL_1_]], [[VAR_1_]])
// CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[MUL_0_]], [[ADD_0_]])
// CHECK-SAME: : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
// CHECK: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_2_]]) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
// CHECK-NEXT: return [[VAR_3_]] : tensor<1x128xf32>
}

// -----

func.func @test_normalize_add(%arg0 : tensor<2xf32>) -> tensor<2xf32> {
%cst = "onnx.NoValue"() {value} : () -> none
%0 = onnx.Constant dense<[0.0, 1.0]> : tensor<2xf32>