Converting Split->conv->concat to Grouped conv #3124
Conversation
Signed-off-by: Kumarappan <[email protected]>
Can one of the admins verify this patch?
@tungld Thanks for the review and verification, could you please approve and merge this patch!
@tungld thanks for the verification, could you please trigger the tests and merge the patch.
@kumarappan-cmyk thank you for the PR!
I have left my first round of review comments. Please refactor the code to make it concise, and please do use DialectBuilder.
For the lit tests, could you add some tests that do not satisfy the conditions for fusion, to make sure the recomposing pattern is not applied in such cases?
llvm::SmallVector<ONNXConvOp, 2> convOps;
ONNXConcatOp concatOp;

// Ensure the pattern exists: Split → Conv → Concat
Please put `.` at the end of all comments.
if (!weightType)
  return failure();
int64_t rank = weightType.getRank();
if (1 >= rank)
`rank < 2` is easier to read?
return failure();
int64_t rank = weightType.getRank();
if (1 >= rank)
  return failure(); // Ensure axis is within valid range
I don't understand the comment. Is this check the same as the following check for the axis?
// **Concatenating Conv Weights Correctly**
SmallVector<Value, 2> weightTensors;
int64_t total_C_out = 0;
Please use camelCase instead of snake_case for naming.
}

// Create correct IntegerAttrs
IntegerAttr axis0 = rewriter.getI64IntegerAttr(0);
Is this value `rewriter.getI64IntegerAttr(0)` unused, since you replace it with `axis0 = IntegerAttr::get(si64Type, 0)` later?
biasTensors.push_back(conv.getB());

Type newBiasType =
    RankedTensorType::get({total_C_out}, weightType.getElementType());
Define Type elementType = weightType.getElementType();
at the beginning of this function and reuse it to avoid boilerplate code.
Type newBiasType =
    RankedTensorType::get({total_C_out}, weightType.getElementType());
axis0 = IntegerAttr::get(si64Type, 0);
Define this at the beginning of the function also. Thanks!
// **Create new Grouped ConvOp**
auto newConv = rewriter.create<ONNXConvOp>(loc, resultType, input,
    concatenatedWeight, hasBias ? concatenatedBias : Value(), autoPadAttr,
    dilationsAttr, groupAttrVal, kernelShapeAttr, padsAttr, stridesAttr);
Please use DialectBuilder for conv, c.f. https://github.com/onnx/onnx-mlir/blob/main/src/Dialect/ONNX/DialectBuilder.hpp#L76
concatenatedBias =
    rewriter.create<ONNXConcatOp>(biasLoc, newBiasType, biasTensors,
        axis0); // Bias should be concatenated along axis=0
}
Please use DialectBuilder for concat.
// RUN: onnx-mlir-opt --recompose-onnx --remove-dead-values --constprop-onnx %s -split-input-file | FileCheck %s

func.func @simple_split_conv_concat(%arg0: tensor<1x6x512x512xf64> {onnx.name = "input"}) -> (tensor<1x6x512x512xf64> {onnx.name = "output"}) {
%0 = onnx.Constant dense<[[[[-0.0017646604683250189, 0.12644097208976746, -0.19399359822273254], [-0.17346249520778656, -0.090781755745410919, 0.0632052943110466], [-0.0046700113452970982, 0.18688584864139557, -0.020917171612381935]], [[0.062369778752326965, -0.071232303977012634, -0.046330906450748444], [-0.22517779469490051, -0.15610139071941376, -0.097161918878555298], [0.008731253445148468, 0.093181401491165161, 0.14142672717571259]]], [[[-0.15979224443435669, -0.1026395708322525, 0.085611097514629364], [0.19572432339191437, -0.048507567495107651, 0.1763787716627121], [-0.037991281598806381, 0.024940622970461845, 0.21342279016971588]], [[-0.21865400671958923, -0.14838351309299469, -0.059671621769666672], [-0.09187673032283783, 0.2036469429731369, -0.15277740359306335], [-0.10850150138139725, -0.16467113792896271, -0.22074954211711884]]]]> : tensor<2x2x3x3xf64> |
Please use splat constants (e.g. `dense<1.0> : tensor<2x2x3x3xf64>`) for better readability, since the exact values here are not the essential part of the test.
This optimization fuses the pattern Split → Conv → Concat into a single Conv operation using grouped convolution, provided the split and concat act on the channel axis and the per-branch convolutions share the same attributes (strides, pads, dilations).
By converting multiple independent convolutions into a grouped convolution, we reduce memory usage, improve cache locality, and lower the number of kernel launches in backend runtimes.
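As a side note (not code from this PR), the equivalence that the rewrite relies on can be checked numerically. The sketch below uses NumPy with a naive `conv2d` helper; the shapes, branch count, and helper itself are illustrative assumptions, not part of the patch:

```python
import numpy as np

def conv2d(x, w, groups=1):
    # Naive grouped 2D convolution: x is (N, C_in, H, W),
    # w is (C_out, C_in/groups, kH, kW); stride 1, no padding.
    n, c_in, h, wd = x.shape
    c_out, c_in_g, kh, kw = w.shape
    oh, ow = h - kh + 1, wd - kw + 1
    c_out_g = c_out // groups
    out = np.zeros((n, c_out, oh, ow))
    for g in range(groups):
        xg = x[:, g * c_in_g:(g + 1) * c_in_g]      # this group's input channels
        wg = w[g * c_out_g:(g + 1) * c_out_g]       # this group's filters
        for oc in range(c_out_g):
            for i in range(oh):
                for j in range(ow):
                    out[:, g * c_out_g + oc, i, j] = np.sum(
                        xg[:, :, i:i + kh, j:j + kw] * wg[oc], axis=(1, 2, 3))
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 6, 8, 8))
# Three branches, each consuming 2 of the 6 input channels.
w1 = rng.standard_normal((2, 2, 3, 3))
w2 = rng.standard_normal((2, 2, 3, 3))
w3 = rng.standard_normal((2, 2, 3, 3))

# Split -> Conv (per branch) -> Concat along the channel axis:
parts = np.split(x, 3, axis=1)
ref = np.concatenate(
    [conv2d(p, w) for p, w in zip(parts, (w1, w2, w3))], axis=1)

# Single Conv with concatenated weights and group = 3:
fused = conv2d(x, np.concatenate([w1, w2, w3], axis=0), groups=3)
assert np.allclose(ref, fused)
```

The per-branch weights are concatenated along the output-channel axis (axis 0), which mirrors what the rewrite does with the Conv weight tensors before setting the group attribute.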
Before optimization – Split + Conv + Concat
After optimization – Conv with adjusted group