Add proper CPU lowering for tosa::transpose_conv2d
#2007
base: develop
Pull Request Overview
This PR implements proper CPU lowering for tosa::transpose_conv2d operations by adding support for dilation and input padding parameters. The changes address a previously broken CPU lowering pipeline for backward data convolution operations.
Key changes:
- Enhanced TosaDecomposeTransposeConv.cpp to handle dilation and input padding in transpose convolution lowering
- Updated verification logic in TosaOps.cpp to use unified formulas that account for optional dilation and input padding parameters
- Refactored test files to use the new lowering path with proper CPU verification instead of workarounds
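For orientation, here is a minimal sketch of the kind of output-size relation such a verifier could check. The exact formula used in TosaOps.cpp is not shown on this page, so the function below assumes the widely used transposed-convolution convention (the one documented for PyTorch's ConvTranspose2d, without output padding). The name transposeConvOutDim and its parameters are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Output extent of a transposed convolution along one spatial axis.
// Assumed convention (not taken from TosaOps.cpp):
//   out = (in - 1) * stride - padBefore - padAfter
//         + dilation * (kernel - 1) + 1
int64_t transposeConvOutDim(int64_t in, int64_t kernel, int64_t stride,
                            int64_t dilation, int64_t padBefore,
                            int64_t padAfter) {
  assert(in > 0 && kernel > 0 && stride > 0 && dilation > 0);
  return (in - 1) * stride - padBefore - padAfter +
         dilation * (kernel - 1) + 1;
}
```

With dilation 1 and no input padding this reduces to (in - 1) * stride + kernel, which matches the TOSA transpose_conv2d output size when the output padding is zero.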
Reviewed Changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
mlir/test/fusion/pr-e2e/mixr-bwd-data-conv.mlir | Updated to use new lowering path with CPU verification |
mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride32.mlir | New test case for stride configuration |
mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride2-dilation2.mlir | Updated to use new lowering path |
mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-padding1.mlir | New test case for padding configuration |
mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation2-stride1.mlir | New test case for dilation configuration |
mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation1-stride2.mlir | Updated to use new lowering path |
mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-asymmetric-stride.mlir | Updated to use new lowering path |
external/llvm-project/mlir/test/Dialect/Tosa/invalid.mlir | Updated error messages to reflect new verification formulas |
external/llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp | Main implementation of dilation and input padding support |
external/llvm-project/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | Updated verification logic with unified formulas |
bool needSlice = false;
SmallVector<int64_t,4> negExcess(4,0);
for (int i=0;i<4;++i) {
  if (convPad[i] < 0) {
    negExcess[i] = -convPad[i];
    convPad[i] = 0;
    needSlice = true;
  }
}
Copilot
AI
Sep 30, 2025
Missing spaces around operators in the for loop condition. Should be for (int i = 0; i < 4; ++i).
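The loop under review clamps negative convolution padding to zero and records the excess so the result can be cropped afterwards. A standalone sketch of that idea, using std::array instead of SmallVector, is given below. It assumes the pad order {top, bottom, left, right} and a stride-1 convolution, where a negative pad of p on one edge is equivalent to cropping p elements from that edge of the unpadded result. ClampedPad, Crop, clampNegativePad and cropForExcess are hypothetical names, not identifiers from the PR.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Assumed pad order: {top, bottom, left, right}.
struct ClampedPad {
  std::array<int64_t, 4> pad;       // non-negative padding passed to the conv
  std::array<int64_t, 4> negExcess; // amount to crop from each output edge
  bool needSlice;                   // true if any pad entry was negative
};

// Clamp negative pads to zero and remember how much was clamped.
ClampedPad clampNegativePad(std::array<int64_t, 4> convPad) {
  ClampedPad result{convPad, {0, 0, 0, 0}, false};
  for (int i = 0; i < 4; ++i) {
    if (result.pad[i] < 0) {
      result.negExcess[i] = -result.pad[i];
      result.pad[i] = 0;
      result.needSlice = true;
    }
  }
  return result;
}

// Start offsets and sizes of the crop applied to a convH x convW result.
struct Crop {
  int64_t startH, startW, sizeH, sizeW;
};

Crop cropForExcess(int64_t convH, int64_t convW,
                   const std::array<int64_t, 4> &negExcess) {
  Crop c{negExcess[0], negExcess[2],
         convH - negExcess[0] - negExcess[1],
         convW - negExcess[2] - negExcess[3]};
  assert(c.sizeH > 0 && c.sizeW > 0);
  return c;
}
```

For example, a pad of {1, -2, 0, -1} becomes {1, 0, 0, 0} with an excess of {0, 2, 0, 1}, so a 10 x 8 convolution result would be cropped to 8 x 7 starting at offset (0, 0).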
};

bool needSlice = false;
SmallVector<int64_t,4> negExcess(4,0);
Copilot
AI
Sep 30, 2025
Missing space after comma in template parameters. Should be SmallVector<int64_t, 4>.
Suggested change:
- SmallVector<int64_t,4> negExcess(4,0);
+ SmallVector<int64_t, 4> negExcess(4,0);
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
/* acc_type = */ op.getAccType(),
op->getAttrOfType<IntegerAttr>("group"))
op.getAccType(), op->getAttrOfType<IntegerAttr>("group"))
Copilot
AI
Sep 30, 2025
[nitpick] Inconsistent parameter formatting compared to the original code. The original had explicit parameter names in comments which improved readability.
719d532 to b173021
Note for reviewers: once I've gotten the necessary approvals, I will separate this into two commits, one for the external llvm-project changes and another for the rocMLIR changes.
if (auto kindAttr = op->getAttrOfType<StringAttr>("conv_kind");
    kindAttr && kindAttr.getValue() == "bwd_data") {
  // Expected current shape: [K, H, W, C] but Conv2D expects [O, H, W, I]
  // Swap K<->C => permutation {3,1,2,0}.
  auto wShape = weightTy.getShape();
  SmallVector<int64_t, 4> swappedShape{
      wShape[3], // C becomes O
      wShape[1],
      wShape[2],
      wShape[0] // K becomes I
  };
  auto swappedTy =
      RankedTensorType::get(swappedShape, weightTy.getElementType());
  weight = rewriter.create<tosa::TransposeOp>(
      loc, swappedTy, weight, rewriter.getDenseI32ArrayAttr({3, 1, 2, 0}));
  weightTy = cast<ShapedType>(weight.getType());
}
Would it make more sense to do this in MIGraphXToTosa, and then again when converting back from TosaToRock?
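To make the shape effect of the {3, 1, 2, 0} permutation concrete, here is a small standalone sketch. It follows the transpose convention in which result dimension i takes input dimension perm[i]. The helper name permuteShape and the example sizes are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Shape of a rank-4 transpose: result dim i takes input dim perm[i].
std::array<int64_t, 4> permuteShape(const std::array<int64_t, 4> &shape,
                                    const std::array<int, 4> &perm) {
  std::array<int64_t, 4> out{};
  for (int i = 0; i < 4; ++i) {
    assert(perm[i] >= 0 && perm[i] < 4);
    out[i] = shape[perm[i]];
  }
  return out;
}
```

A weight of shape [K, H, W, C] = [64, 3, 3, 16] maps to [16, 3, 3, 64]. Because the permutation only swaps dimensions 0 and 3, it is its own inverse, so applying it a second time restores the original layout.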
b173021 to b5b0f1c
Do we want to have an upstream PR for this? They'll probably say it's not part of the TOSA spec, but maybe we can open a discussion to include this?
They explicitly had PRs removing dilation attributes from the code a couple of years ago. It still may be worth a try. Do you know which forum would be best for asking for changes to the TOSA spec?
In
// Set attributes common to both forwards and backwards conv
cop->setAttr("dilation", rewriter.getDenseI64ArrayAttr(dilations));
cop->setAttr("stride", rewriter.getDenseI64ArrayAttr(strides));
cop->setAttr("acc_type", TypeAttr::get(accType));
But I don't see anywhere in this PR where those attributes are used. Is that intentional?
They are both used in
Motivation
This PR implements proper CPU lowering for tosa::transpose_conv2d operations by adding support for dilation and input padding attributes. This means we can now remove the LIT test workaround in the bwd_data_conv e2e tests that was introduced in #1951.
Implements: https://github.com/ROCm/rocMLIR-internal/issues/1990
Technical Details
This PR implements the following key changes:
- Updated TransposeConvNonStridedConverter and TransposeConvStridedConverter to both handle input padding and dilation attributes (see the comments in the new logic for more details)
Test Plan
Test Result
Submission Checklist