
Conversation

@justinrosner (Contributor) commented on Sep 30, 2025

Motivation

This PR implements proper CPU lowering for tosa::transpose_conv2d operations by adding support for dilation and input padding attributes. With this change we can remove the LIT test workarounds in the bwd_data_conv e2e tests that were introduced in #1951.

Implements: https://github.com/ROCm/rocMLIR-internal/issues/1990

Technical Details

This PR implements the following key changes:

  • Updated the verification function in TosaOps.cpp to use a unified approach that merges the legacy formula with our updated one (a hedged sketch of the unified output-size relation follows this list)
  • Updated TransposeConvNonStridedConverter and TransposeConvStridedConverter to handle both the input padding and dilation attributes (see the comments in the new logic for more details)
  • Updated the bwd_data_conv e2e tests that no longer need a workaround
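
For concreteness, here is a rough sketch of the unified output-size relation described in the first item. It is an illustration written from this description, not the actual code in TosaOps.cpp, and the parameter names are assumptions; with a dilation of 1 and no input padding it collapses back to the legacy TOSA formula outH = (inH - 1) * strideH + outPadTop + outPadBottom + kernelH.

#include <cstdint>

// Hedged sketch of the unified transpose_conv2d output-height relation.
// When dilationH == 1 and padTop == padBottom == 0, the term
// dilationH * (kernelH - 1) + 1 reduces to kernelH, recovering the legacy
// formula, so a single check covers both the old and the new cases.
int64_t expectedOutputHeight(int64_t inH, int64_t kernelH, int64_t strideH,
                             int64_t dilationH, int64_t padTop,
                             int64_t padBottom, int64_t outPadTop,
                             int64_t outPadBottom) {
  return (inH - 1) * strideH + outPadTop + outPadBottom - padTop - padBottom +
         dilationH * (kernelH - 1) + 1;
}

The width dimension follows the same relation with the corresponding stride, dilation, and padding values.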

Test Plan

  • Update existing bwd_conv e2e tests to use the new lowering path
  • Add new e2e tests for more edge cases

Test Result

  • Nightly CI (with MIGraphX tests enabled)
  • All bwd_conv e2e LIT tests are passing

Submission Checklist

@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR implements proper CPU lowering for tosa::transpose_conv2d operations by adding support for dilation and input padding parameters. The changes address a previously broken CPU lowering pipeline for backward data convolution operations.

Key changes:

  • Enhanced TosaDecomposeTransposeConv.cpp to handle dilation and input padding in the transpose convolution lowering (a sketch of the padding derivation follows this list)
  • Updated verification logic in TosaOps.cpp to use unified formulas that account for optional dilation and input padding parameters
  • Refactored test files to use the new lowering path with proper CPU verification instead of workarounds
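
As a reading aid for the first key change above, here is a minimal sketch of how the non-strided lowering can derive the forward conv2d padding from the kernel size, dilation, and input padding, and why a negative result turns into a slice of the output (the needSlice/negExcess logic quoted in the review comments below). All names are illustrative; the exact handling in TosaDecomposeTransposeConv.cpp may differ.

#include <array>
#include <cstdint>

// Illustrative plan for the non-strided transpose_conv2d -> conv2d rewrite.
// A stride-1 transposed convolution with input padding p is equivalent to a
// forward convolution over the reversed kernel padded by dilation*(k-1) - p
// on each side. conv2d cannot take negative padding, so negative values are
// clamped to zero and the excess is sliced off the result afterwards.
struct ConvPadPlan {
  std::array<int64_t, 4> convPad{};   // {top, bottom, left, right} for conv2d
  std::array<int64_t, 4> negExcess{}; // rows/cols to slice off afterwards
  bool needSlice = false;
};

ConvPadPlan planConvPadding(std::array<int64_t, 2> kernel,   // {kh, kw}
                            std::array<int64_t, 2> dilation, // {dh, dw}
                            std::array<int64_t, 4> inputPad) // {top, bottom, left, right}
{
  ConvPadPlan plan;
  plan.convPad = {dilation[0] * (kernel[0] - 1) - inputPad[0],
                  dilation[0] * (kernel[0] - 1) - inputPad[1],
                  dilation[1] * (kernel[1] - 1) - inputPad[2],
                  dilation[1] * (kernel[1] - 1) - inputPad[3]};
  for (int i = 0; i < 4; ++i) {
    if (plan.convPad[i] < 0) {
      plan.negExcess[i] = -plan.convPad[i];
      plan.convPad[i] = 0;
      plan.needSlice = true;
    }
  }
  return plan;
}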

Reviewed Changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

Summary per file:

  • mlir/test/fusion/pr-e2e/mixr-bwd-data-conv.mlir: Updated to use new lowering path with CPU verification
  • mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride32.mlir: New test case for stride configuration
  • mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride2-dilation2.mlir: Updated to use new lowering path
  • mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-padding1.mlir: New test case for padding configuration
  • mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation2-stride1.mlir: New test case for dilation configuration
  • mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation1-stride2.mlir: Updated to use new lowering path
  • mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-asymmetric-stride.mlir: Updated to use new lowering path
  • external/llvm-project/mlir/test/Dialect/Tosa/invalid.mlir: Updated error messages to reflect new verification formulas
  • external/llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp: Main implementation of dilation and input padding support
  • external/llvm-project/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp: Updated verification logic with unified formulas


Comment on lines +91 to +97
bool needSlice = false;
SmallVector<int64_t,4> negExcess(4,0);
for (int i=0;i<4;++i) {
  if (convPad[i] < 0) {
    negExcess[i] = -convPad[i];
    convPad[i] = 0;
    needSlice = true;
  }
}
Copilot AI commented on Sep 30, 2025:

Missing spaces around operators in the for loop condition. Should be for (int i = 0; i < 4; ++i).


};

bool needSlice = false;
SmallVector<int64_t,4> negExcess(4,0);
Copilot AI commented on Sep 30, 2025:

Missing space after comma in template parameters. Should be SmallVector<int64_t, 4>.

Suggested change:
- SmallVector<int64_t,4> negExcess(4,0);
+ SmallVector<int64_t, 4> negExcess(4,0);


/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
/* acc_type = */ op.getAccType(),
op->getAttrOfType<IntegerAttr>("group"))
op.getAccType(), op->getAttrOfType<IntegerAttr>("group"))
Copilot AI commented on Sep 30, 2025:

[nitpick] Inconsistent parameter formatting compared to the original code. The original had explicit parameter names in comments which improved readability.


justinrosner force-pushed the justinr-cpu-transpose-conv2d branch from 719d532 to b173021 on September 30, 2025 20:27
@justinrosner (Author) commented:

Note for reviewers: once I've received the necessary approvals, I will split this into two commits, one for the external llvm-project changes and another for the rocMLIR changes.

Comment on lines 148 to 165
if (auto kindAttr = op->getAttrOfType<StringAttr>("conv_kind");
    kindAttr && kindAttr.getValue() == "bwd_data") {
  // Expected current shape: [K, H, W, C] but Conv2D expects [O, H, W, I]
  // Swap K<->C => permutation {3,1,2,0}.
  auto wShape = weightTy.getShape();
  SmallVector<int64_t, 4> swappedShape{
      wShape[3], // C becomes O
      wShape[1],
      wShape[2],
      wShape[0] // K becomes I
  };
  auto swappedTy =
      RankedTensorType::get(swappedShape, weightTy.getElementType());
  weight = rewriter.create<tosa::TransposeOp>(
      loc, swappedTy, weight, rewriter.getDenseI32ArrayAttr({3, 1, 2, 0}));
  weightTy = cast<ShapedType>(weight.getType());
}
@justinrosner (Author) commented:

Would it make more sense to do this in MIGraphXToTosa, and then again when converting back from TosaToRock?

justinrosner force-pushed the justinr-cpu-transpose-conv2d branch from b173021 to b5b0f1c on September 30, 2025 21:00
justinrosner marked this pull request as ready for review on October 1, 2025 12:51
justinrosner requested a review from causten as a code owner on October 1, 2025 12:51
@dhernandez0 (Contributor) commented:

Do we want to have an upstream PR for this? They'll probably say it's not part of the TOSA spec, but maybe we can open a discussion to include it?

@justinrosner (Author) replied:

They explicitly had PRs removing the dilation attributes from the code a couple of years ago. It may still be worth a try. Do you know which forum would be best for proposing changes to the TOSA spec?

@pabloantoniom (Contributor) commented:

In MIGraphXToTosa we do:

// Set attributes common to both forwards and backwards conv
cop->setAttr("dilation", rewriter.getDenseI64ArrayAttr(dilations));
cop->setAttr("stride", rewriter.getDenseI64ArrayAttr(strides));
cop->setAttr("acc_type", TypeAttr::get(accType));

But I don't see anywhere in this PR where those attributes are used. Is that intentional?

@justinrosner (Author) replied:

They are both used in TosaDecomposeTransposeConv. For the simple, non-strided case, we pass dilation through and use the values directly on the forward conv op that is created. For the strided case we use both the strides and the dilation in the calculation. acc_type is also set for both.
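
For readers tracing the data flow, here is a minimal sketch of what that consumption could look like inside the decomposition, assuming the attributes are read back exactly as MIGraphXToTosa set them; the helper name and defaults are illustrative, not the PR's actual code.

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: recover the "dilation" attribute attached to the
// transpose_conv2d op, defaulting to {1, 1} when it is absent. The same
// pattern would apply to "stride"; the non-strided path forwards the values
// onto the forward conv2d it creates, while the strided path feeds them into
// the padding/reshape arithmetic.
static llvm::SmallVector<int64_t, 2> getDilationOrDefault(mlir::Operation *op) {
  llvm::SmallVector<int64_t, 2> dilation{1, 1};
  if (auto attr = op->getAttrOfType<mlir::DenseI64ArrayAttr>("dilation"))
    dilation.assign(attr.asArrayRef().begin(), attr.asArrayRef().end());
  return dilation;
}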

