
Conversation

kumarappan-cmyk
Contributor

This optimization fuses the pattern Split → Conv → Concat into a single Conv operation using grouped convolution when:

  • The input tensor is split along the channel axis.
  • Each split feeds into a separate convolution (Conv) layer.
  • The outputs of these convolution layers are concatenated along the output channel axis.

By converting multiple independent convolutions into a grouped convolution, we reduce memory usage, improve cache locality, and lower the number of kernel launches in backend runtimes.

Before optimization: Split → Conv → Concat [diagram]

After optimization: a single Conv with an adjusted group attribute [diagram]
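The equivalence behind this rewrite can be sketched numerically. The following is a minimal NumPy sketch (the function names and the 1×1-kernel simplification are mine, not from the PR): splitting the input channels, convolving each half independently, and concatenating the outputs gives the same result as one grouped convolution whose weight and bias are the per-branch tensors concatenated along the output-channel axis, which is exactly how the PR builds the fused Conv.

```python
import numpy as np

def conv1x1(x, w, b):
    # x: (C_in, H, W); w: (C_out, C_in); b: (C_out,).
    # A 1x1 convolution is a per-pixel channel mix plus bias.
    return np.einsum("oc,chw->ohw", w, x) + b[:, None, None]

def grouped_conv1x1(x, w, b, groups):
    # w: (C_out, C_in // groups); each group of output channels
    # sees only its own contiguous slice of input channels.
    xs = np.split(x, groups, axis=0)
    ws = np.split(w, groups, axis=0)
    bs = np.split(b, groups, axis=0)
    return np.concatenate(
        [conv1x1(xi, wi, bi) for xi, wi, bi in zip(xs, ws, bs)], axis=0)

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 4, 4))       # 6 input channels
w1, w2 = rng.standard_normal((2, 2, 3))  # two branches: 2 out-ch, 3 in-ch each
b1, b2 = rng.standard_normal((2, 2))

# Original pattern: Split -> Conv -> Concat.
lo, hi = np.split(x, 2, axis=0)
split_conv_concat = np.concatenate(
    [conv1x1(lo, w1, b1), conv1x1(hi, w2, b2)], axis=0)

# Fused pattern: one grouped conv; weights concatenated along the
# output-channel axis, biases concatenated along axis 0.
fused = grouped_conv1x1(
    x, np.concatenate([w1, w2], axis=0), np.concatenate([b1, b2]), groups=2)

assert np.allclose(split_conv_concat, fused)
```

The same argument extends to arbitrary kernel sizes: grouped convolution is defined as exactly this partition of input and output channels, so the rewrite is value-preserving whenever the splits are equal-sized along the channel axis and each Conv shares the remaining attributes.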

@jenkins-droid
Collaborator

Can one of the admins verify this patch?

Signed-off-by: Kumarappan <[email protected]>

@Arkar-Hema
Contributor

@tungld Thanks for the review and verification. Could you please approve and merge this patch?

@kumarappan-cmyk changed the title from "Converting Split->conv->concat to Grouped concat" to "Converting Split->conv->concat to Grouped conv" on Apr 18, 2025

@Arkar-Hema
Contributor

@tungld Thanks for the verification. Could you please trigger the tests and merge the patch?

@tungld
Collaborator

@kumarappan-cmyk thank you for the PR!

I've posted my first round of review comments. Please refactor the code to make it more concise, and please do use DialectBuilder.

For the lit tests, could you add some cases that do not satisfy the conditions for fusion, to make sure the recomposing pattern is not applied in such cases?

llvm::SmallVector<ONNXConvOp, 2> convOps;
ONNXConcatOp concatOp;

// Ensure the pattern exists: Split → Conv → Concat

Please put a . at the end of all comments.

if (!weightType)
return failure();
int64_t rank = weightType.getRank();
if (1 >= rank)

rank < 2 is easier to read?

return failure();
int64_t rank = weightType.getRank();
if (1 >= rank)
return failure(); // Ensure axis is within valid range

I don't understand the comment. Is this check the same as the following check for axis?


// **Concatenating Conv Weights Correctly**
SmallVector<Value, 2> weightTensors;
int64_t total_C_out = 0;

Please use camelCase instead of snake_case for naming.

}

// Create correct IntegerAttrs
IntegerAttr axis0 = rewriter.getI64IntegerAttr(0);

Is the value rewriter.getI64IntegerAttr(0) unused, since you replace it with axis0 = IntegerAttr::get(si64Type, 0) later?

biasTensors.push_back(conv.getB());

Type newBiasType =
RankedTensorType::get({total_C_out}, weightType.getElementType());

Define Type elementType = weightType.getElementType(); at the beginning of this function and reuse it to avoid boilerplate code.


Type newBiasType =
RankedTensorType::get({total_C_out}, weightType.getElementType());
axis0 = IntegerAttr::get(si64Type, 0);

Define this at the beginning of the function also. Thanks!

// **Create new Grouped ConvOp**
auto newConv = rewriter.create<ONNXConvOp>(loc, resultType, input,
concatenatedWeight, hasBias ? concatenatedBias : Value(), autoPadAttr,
dilationsAttr, groupAttrVal, kernelShapeAttr, padsAttr, stridesAttr);

concatenatedBias =
rewriter.create<ONNXConcatOp>(biasLoc, newBiasType, biasTensors,
axis0); // Bias should be concatenated along axis=0
}

Please use DialectBuilder for concat.

// RUN: onnx-mlir-opt --recompose-onnx --remove-dead-values --constprop-onnx %s -split-input-file | FileCheck %s

func.func @simple_split_conv_concat(%arg0: tensor<1x6x512x512xf64> {onnx.name = "input"}) -> (tensor<1x6x512x512xf64> {onnx.name = "output"}) {
%0 = onnx.Constant dense<[[[[-0.0017646604683250189, 0.12644097208976746, -0.19399359822273254], [-0.17346249520778656, -0.090781755745410919, 0.0632052943110466], [-0.0046700113452970982, 0.18688584864139557, -0.020917171612381935]], [[0.062369778752326965, -0.071232303977012634, -0.046330906450748444], [-0.22517779469490051, -0.15610139071941376, -0.097161918878555298], [0.008731253445148468, 0.093181401491165161, 0.14142672717571259]]], [[[-0.15979224443435669, -0.1026395708322525, 0.085611097514629364], [0.19572432339191437, -0.048507567495107651, 0.1763787716627121], [-0.037991281598806381, 0.024940622970461845, 0.21342279016971588]], [[-0.21865400671958923, -0.14838351309299469, -0.059671621769666672], [-0.09187673032283783, 0.2036469429731369, -0.15277740359306335], [-0.10850150138139725, -0.16467113792896271, -0.22074954211711884]]]]> : tensor<2x2x3x3xf64>

Please use splat constants for better readability, since the exact values here are not the essential part of the test.
