Backward fold scale axis to Gemm layer in ONNX Dialect #3131
Conversation
Signed-off-by: Arkar-Hema <[email protected]>
Can any one of the admins verify this patch?
// (BatchNorm) Y = scale * (Z - mean) / sqrt(var + eps) + bias
//
// This transformation corresponds to a recomposition:
// Y = A * (scale * B) + (scale * C + bias)
@Arkar-Hema could you elaborate how you derived this formula, where mean, var and eps are canceled?
Assume mean=0, var=1 and eps=0 (which is typically the case in pre-compiled, normalised models):
Y ≈ scale × Z + bias
Substituting Z = A × B + C:
Y ≈ scale × (A × B + C) + bias
Y = A × (scale × B) + (scale × C + bias)
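A minimal NumPy sketch of this derivation (not the onnx-mlir code; shapes and variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8))   # Gemm input
B = rng.standard_normal((8, 3))   # Gemm weights
C = rng.standard_normal(3)        # Gemm bias
scale = rng.standard_normal(3)    # BatchNorm scale (per output channel)
bias = rng.standard_normal(3)     # BatchNorm bias (per output channel)

Z = A @ B + C
# With mean=0, var=1, eps=0 the BatchNorm degenerates to an affine op:
Y_bn = scale * Z + bias
# Folded form: Y = A × (scale × B) + (scale × C + bias)
Y_folded = A @ (B * scale) + (scale * C + bias)
assert np.allclose(Y_bn, Y_folded)
```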
> Assume mean=0, var=1 and eps=0 (which is typically the case in pre-compiled, normalised models):
Then, you have to define this assumption in the constraint part of the rewriting rule. Otherwise, the rewriting rule produces a wrong result.
Anyway, my recommendation is to handle the general case where mean, var, and eps are constants (not necessarily the concrete values 0, 1, and 0, respectively). New scale and bias values for the matmul can easily be computed from mean, var, eps, scale, and bias of the BatchNorm, and in inference mode these values are constants and will be folded automatically by the compiler into a single constant.
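A hedged NumPy sketch of this general-case fold (again illustrative names and shapes, not the onnx-mlir rewriting rule), where mean, var and eps are arbitrary constants rather than 0, 1 and 0:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((4, 8))   # Gemm input
B = rng.standard_normal((8, 3))   # Gemm weights
C = rng.standard_normal(3)        # Gemm bias
scale = rng.standard_normal(3)
bias = rng.standard_normal(3)
mean = rng.standard_normal(3)
var = rng.random(3) + 0.5         # keep variance positive
eps = 1e-5

# Effective per-channel multiplier of the BatchNorm:
s = scale / np.sqrt(var + eps)
new_B = B * s                     # folded Gemm weights
new_C = (C - mean) * s + bias     # folded Gemm bias

Y_ref = scale * ((A @ B + C) - mean) / np.sqrt(var + eps) + bias
Y_folded = A @ new_B + new_C
assert np.allclose(Y_ref, Y_folded)
```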
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
BackwardFoldScaleAxisToGemm
BackwardFoldScaleAxisToGemm is an optimization pass in ONNX-MLIR that merges a BatchNormalization layer into a preceding Gemm (General Matrix Multiplication) layer when specific conditions are met.
Goal: simplify the model by statically folding the BatchNormalization into the Gemm operation, reducing redundant computation and improving runtime efficiency.
Conditions: the BatchNormalization directly follows the Gemm, and its parameters are constants; the formula below further assumes mean = 0, var = 1, and eps = 0, as discussed in the review above.
Original computation:
Output = BatchNorm(Gemm(X, W, B))
Computation after the pass:
NewOutput = Gemm(X, W × γ, B × γ + β)
where γ is the BatchNormalization scale and β is its bias.
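A short NumPy check of this formula in the description's notation (a sketch under the simplified mean = 0, var = 1, eps = 0 assumption; shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.standard_normal((2, 5))
W = rng.standard_normal((5, 3))
B = rng.standard_normal(3)
gamma = rng.standard_normal(3)    # BatchNorm scale γ
beta = rng.standard_normal(3)     # BatchNorm bias β

# Output = BatchNorm(Gemm(X, W, B)) with mean=0, var=1, eps=0:
output = gamma * (X @ W + B) + beta
# NewOutput = Gemm(X, W × γ, B × γ + β):
new_output = X @ (W * gamma) + (B * gamma + beta)
assert np.allclose(output, new_output)
```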