Conversation

Arkar-Hema
Contributor

BackwardFoldScaleAxisToGemm

BackwardFoldScaleAxisToGemm is an optimization pass in ONNX-MLIR that merges a BatchNormalization layer into a preceding Gemm (General Matrix Multiplication) layer when specific conditions are met.
Goal: to simplify the model and improve runtime efficiency by statically folding the BatchNormalization into the Gemm operation, removing redundant computation at inference time.

Conditions:

  • The Gemm op must have transB = 0 (i.e., the weight matrix is used without transposition)
  • The BatchNormalization must be in inference mode

Original computation:
Output = BatchNorm(Gemm(X, W, B))

After pass computation (a code sketch follows the list below):
NewOutput = Gemm(X, W × γ, B × γ + β)
where:

  • X = Input to Gemm
  • W = Weights in Gemm
  • B = Bias in Gemm
  • γ = Scale in BatchNormalization
  • β = Bias in BatchNormalization
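
A minimal numeric sketch of the before/after computation (hypothetical shapes and values; as the discussion below points out, the two forms agree exactly only when the BatchNorm statistics are trivial, i.e. mean = 0, var = 1, eps = 0):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 4))   # Gemm input
W = rng.standard_normal((4, 3))   # Gemm weights (transB = 0)
B = rng.standard_normal(3)        # Gemm bias
gamma = rng.standard_normal(3)    # BatchNorm scale (γ)
beta = rng.standard_normal(3)     # BatchNorm bias (β)
mean, var, eps = np.zeros(3), np.ones(3), 0.0   # trivial statistics

# Original: Output = BatchNorm(Gemm(X, W, B))
Z = X @ W + B
output = gamma * (Z - mean) / np.sqrt(var + eps) + beta

# Folded: NewOutput = Gemm(X, W × γ, B × γ + β)
new_output = X @ (W * gamma) + (B * gamma + beta)

assert np.allclose(output, new_output)  # only holds for the trivial statistics above
```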

@Arkar-Hema
Contributor Author

Can any one of the admins verify this patch?

// (BatchNorm) Y = scale * (Z - mean) / sqrt(var + eps) + bias
//
// This transformation corresponds to a recomposition:
// Y = A * (scale * B) + (scale * bias + C)
Collaborator

@Arkar-Hema could you elaborate on how you derived this formula, in which mean, var, and eps are canceled?

Contributor Author

Assume mean=0, var=1 and eps=0 (which is usually the case in pre-compiled and normalised models):

Y ≈ scale × Z + bias
Substituting Z = A × B + C:
Y ≈ scale × (A × B + C) + bias
Y ≈ A × (scale × B) + (scale × C + bias)

Collaborator

Assume mean=0, var=1 and eps=0 (which is usually the case in pre-compiled and normalised models):

Then, you have to define this assumption in the constraint part of the rewriting rule. Otherwise, the rewriting rule produces a wrong result.

Anyway, my recommendation is to handle the general case where mean, var, and eps are constants (not necessarily the concrete values 0, 1, and 0, respectively). New scale and bias values for the matmul can be easily computed from the mean, var, eps, scale, and bias of BatchNorm; in inference mode these are all constants and will be folded automatically by the compiler into a single constant.
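
For reference, a hedged sketch of that general-case fold (the variable names and helper function here are illustrative, not the pass's actual implementation):

```python
import numpy as np

def fold_batchnorm_into_gemm(W, B, gamma, beta, mean, var, eps):
    """Fold inference-mode BatchNorm into a preceding Gemm (transB = 0).

    BatchNorm(X @ W + B) = gamma * (X @ W + B - mean) / sqrt(var + eps) + beta
                         = X @ (W * s) + ((B - mean) * s + beta),  s = gamma / sqrt(var + eps)
    """
    s = gamma / np.sqrt(var + eps)
    new_W = W * s                   # scale each output channel (column) of W
    new_B = (B - mean) * s + beta   # mean/var/eps fold into the new bias
    return new_W, new_B

# Quick check with arbitrary constant statistics
rng = np.random.default_rng(1)
X, W = rng.standard_normal((2, 4)), rng.standard_normal((4, 3))
B, gamma, beta = rng.standard_normal(3), rng.standard_normal(3), rng.standard_normal(3)
mean, var, eps = rng.standard_normal(3), rng.random(3) + 0.1, 1e-5

new_W, new_B = fold_batchnorm_into_gemm(W, B, gamma, beta, mean, var, eps)
reference = gamma * (X @ W + B - mean) / np.sqrt(var + eps) + beta
assert np.allclose(X @ new_W + new_B, reference)
```

In the actual rewrite these new scale and bias values would appear as constant expressions that the compiler's constant folding reduces to single constants, as noted above.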

@jenkins-droid
Collaborator

Can one of the admins verify this patch?
