
Conversation

@pabloantoniom (Contributor) commented Oct 7, 2025

Motivation

We want to add support for lowering transposed Conv3d (backward convolution). However, there is no tosa.transpose_conv3d upstream (only tosa.transpose_conv2d), so we decided to go through tosa::CustomOp instead. On GPU the CustomOp will be lowered to rock (easy); on CPU it will be lowered to tosa.reverse + tosa.conv3d.
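
For intuition, the CPU path follows the same recipe as upstream TosaDecomposeTransposeConv: spatially reverse the filter, then run a regular forward convolution with adjusted padding (plus zero-insertion of the input when stride > 1). Below is a minimal sketch of the filter-reversal step for the 3D case; the filter layout, axis numbering, and builder signatures are assumptions for illustration, not code from this PR:

```cpp
// Sketch: reverse the filter along its three spatial axes (assumed to be
// axes 1, 2, 3 for an NDHWC-style filter layout), so that a plain
// tosa.conv3d over the reversed filter computes the transposed convolution.
Value reversed = filter;
for (int32_t axis : {1, 2, 3})
  reversed = rewriter.create<tosa::ReverseOp>(
      loc, reversed.getType(), reversed, rewriter.getI32IntegerAttr(axis));
// ...then emit tosa::Conv3DOp on (input, reversed) with per-dimension
// padding adjusted to kernelSize - 1 - originalPad.
```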

Because we currently lower both transposed Conv1d and Conv2d via tosa.transpose_conv2d, we want to change the 1D/2D approach first, so that when we introduce 3D, all transposed convolutions are lowered along the same path. This PR is therefore preparation for a future PR.

Technical Details

  • In MIGraphXToTosa.cpp we now use tosa::CustomOp to represent backward convolutions (see the sketch after this list)
  • To lower the CustomOp representing the backward convolution:
    • CPU: A new pass, rocmlir-custom-tosa-decompose, is introduced (based on the upstream pass).
    • GPU: In TosaToRock.cpp, ConvConverter is renamed to ForwardConvConverter (which now handles forward conv only), and a new BackwardConvConverter is added to lower the CustomOp representing backward convolutions. Two separate converters were needed because the original one could not accept both Conv2DOp and CustomOp.
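
For readers who have not seen tosa.custom before, the MIGraphXToTosa emission amounts to something like the sketch below. The pad/stride/dilation attribute names and the operator_name values come from the review discussion further down; the builder argument order, the domain_name value, and the array attribute types are assumptions:

```cpp
// Sketch: encode a backward convolution as tosa.custom, tagging the kind
// via operator_name and carrying the conv config as discrete attributes
// that the GPU (TosaToRock) and CPU (decompose) paths read back.
auto customOp = rewriter.create<tosa::CustomOp>(
    loc, TypeRange{outputType},
    /*operator_name=*/rewriter.getStringAttr("conv_bwd_data"),
    /*domain_name=*/rewriter.getStringAttr("rocmlir"),      // assumed value
    /*implementation_attrs=*/rewriter.getStringAttr(""),
    /*input_list=*/ValueRange{input, filter});
customOp->setAttr("pad", rewriter.getDenseI64ArrayAttr(padding));
customOp->setAttr("stride", rewriter.getDenseI64ArrayAttr(strides));
customOp->setAttr("dilation", rewriter.getDenseI64ArrayAttr(dilations));
rewriter.replaceOp(op, customOp->getResults());
```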

Test Plan

  • Modified migraphx-to-tosa.mlir to check for custom rather than transpose_conv2d
  • Added 2 new tests to tosa-to-rock.mlir to exercise backward convolution (these appear to have been missing before)

Test Result

Tests pass.

Submission Checklist

@pabloantoniom requested a review from causten as a code owner on October 7, 2025 at 09:43
@pabloantoniom changed the title from "[DRAFT] Lower convolutions using tosa::CustomOp" to "[DRAFT] Lower 1D and 2D convolutions using tosa::CustomOp" on Oct 9, 2025
@pabloantoniom changed the title from "[DRAFT] Lower 1D and 2D convolutions using tosa::CustomOp" to "[DRAFT] Lower 1D and 2D backward convolution ops using tosa::CustomOp" on Oct 9, 2025
@pabloantoniom changed the title from "[DRAFT] Lower 1D and 2D backward convolution ops using tosa::CustomOp" to "Lower 1D and 2D backward convolution ops using tosa::CustomOp" on Oct 9, 2025
}
}

template <typename OpT>

Contributor:

if it's only for forward conv, can't we use ConvOp here?

Contributor Author (@pabloantoniom), Oct 10, 2025:

It accepts Conv2DOp and Conv3DOp, which are different ops

auto operands = adaptor.getOperands();
auto loc = op->getLoc();
auto *context = op->getContext();
auto input = operands[0];

Contributor:

I thought depending on the conv, input and filter might be switched?

Contributor Author (@pabloantoniom):

I don't know...which type of conv should have the inputs switched?

Contributor:

backward convs

Contributor Author (@pabloantoniom):

I'm not sure. This function was handling 2D backward convolution before, and it does not seem to switch the input and filter. Maybe it was wrong originally?

Contributor:

We have functions like this that would handle this discrepancy:

OpOperand *ConvBwdDataOp::getOutArgument() { return &(*this)->getOpOperand(1); }
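
For context, a sketch of how such an accessor would be used so callers never hard-code which operand slot holds which tensor (illustrative use, not code from the patch):

```cpp
// rock::ConvBwdDataOp computes the data gradient into operand 1, so the
// role is queried through the accessor rather than assumed by position.
OpOperand *out = convBwdDataOp.getOutArgument();
Value dataGradient = out->get();
```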

void runOnOperation() override {
  auto func = getOperation();
  if (!func->hasAttr("kernel")) {
    LDBG() << "func has no kernel attribute, skipping pass";

Contributor:

what's the difference between this and LLVM_DEBUG(llvm::dbgs())?

Contributor Author (@pabloantoniom), Oct 10, 2025:

I didn't actually want to commit this code 😅
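
For anyone landing on this thread later: as I understand it, LDBG() (from llvm/Support/DebugLog.h) is newer convenience sugar over the classic macro; treat the header path and exact expansion as assumptions. Roughly:

```cpp
#include "llvm/Support/Debug.h"    // LLVM_DEBUG, llvm::dbgs()
#include "llvm/Support/DebugLog.h" // LDBG() -- assumed header location

#define DEBUG_TYPE "rocmlir-custom-tosa-decompose"

static void debugExample() {
  // Classic form: explicit stream and explicit trailing newline.
  LLVM_DEBUG(llvm::dbgs() << "func has no kernel attribute, skipping pass\n");
  // Newer form: picks up DEBUG_TYPE and appends the newline itself.
  LDBG() << "func has no kernel attribute, skipping pass";
}
```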

return failure();

Value result;
Operation *rockConvOp = rockConv->getOperation();

Contributor:

same here, ConvOp?

auto bias = operands[2];
RankedTensorType outputType = cast<RankedTensorType>(op.getType(0));

rock::GemmFeatures features = getGemmFeaturesFromOp(op, input.getType());

Contributor:

this code is very similar to the one in ConvConverter, can you make a common function for both?

Contributor Author (@pabloantoniom), Oct 10, 2025:

Yes, the function is very similar, as I mentioned in the PR description, but the split is required because OpConversionPattern does not like that I want to pass:

  • Conv2DOp
  • Conv3DOp
  • CustomOp

to the same function. The problem is with ConvNDOp and CustomOp: they are not compatible.

This is a terrible hack and I'm happy to hear any suggestion

Contributor:

is it possible to use templates to reuse code?
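
One shape the template approach could take, assuming the shared lowering body only needs an Operation* plus the remapped operands (all identifiers here are illustrative, not the PR's):

```cpp
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical shared body: takes generic handles, so the same code can
// serve tosa::Conv2DOp, tosa::Conv3DOp, and tosa::CustomOp.
LogicalResult lowerConvLikeToRock(Operation *op, ValueRange operands,
                                  ConversionPatternRewriter &rewriter);

template <typename OpT>
struct ConvLikeConverter : public OpConversionPattern<OpT> {
  using OpConversionPattern<OpT>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<OpT>::OpAdaptor;

  LogicalResult
  matchAndRewrite(OpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // All per-op work happens in the shared body; the pattern itself only
    // provides the statically typed entry point the framework requires.
    return lowerConvLikeToRock(op, adaptor.getOperands(), rewriter);
  }
};

// Registration would then instantiate one pattern per op:
// patterns.add<ConvLikeConverter<tosa::Conv2DOp>,
//              ConvLikeConverter<tosa::Conv3DOp>,
//              ConvLikeConverter<tosa::CustomOp>>(typeConverter, context);
```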

//===----------------------------------------------------------------------===//
//
// This pass is a downstream version of
// mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Contributor:

Please, can you add a commit hash so we have a reference for when we copied that code?

Contributor Author (@pabloantoniom):

Done

Contributor (@justinrosner) left a comment:

Have you done some running of the bwd_data e2e tests to verify that this works? e.g., https://github.com/ROCm/rocMLIR/blob/develop/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation1-stride2.mlir

return op->emitError("should have 1 result");

// Verify all required attributes are present. group is optional.
for (std::string attrName : {"pad", "stride", "dilation", "conv_kind"}) {

Contributor:

If we now have conv_bwd_data and conv_bwd_weight, we no longer need conv_kind. It was initially added to distinguish between bwd_weight and bwd_data when both were using tosa.transpose_conv2d.

We can remove all references to this.

Contributor Author (@pabloantoniom):

Absolutely, I forgot to do this! I have pushed a fix and now we just use the CustomOp operator_name instead of conv_kind.
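
The dispatch then reduces to a string compare on the custom op (a sketch; the getter name is assumed from the operator_name ODS attribute, and the rock op names follow the snippets quoted in this review):

```cpp
// Pick the rock op from tosa.custom's operator_name, which replaces the
// old conv_kind attribute entirely.
StringRef kind = customOp.getOperatorName();
if (kind == "conv_bwd_data") {
  // ... build rock::ConvBwdDataOp
} else if (kind == "conv_bwd_weight") {
  // ... build rock::ConvBwdWeightOp
} else {
  return failure(); // not a backward convolution we recognize
}
```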

auto operands = adaptor.getOperands();
auto loc = op->getLoc();
auto *context = op->getContext();
auto input = operands[0];

Contributor:

We have functions like this that would handle this discrepancy:

OpOperand *ConvBwdDataOp::getOutArgument() { return &(*this)->getOpOperand(1); }

ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

// Translate acc_type, padding and stride attributes.
// TODO: Should we have all those attribute names as a constant and share

Contributor:

Should we file a case to track this if this is actually something that we want to do? Just so that this doesn't get lost as a TODO?

Contributor Author (@pabloantoniom):

> Should we file a case to track this if this is actually something that we want to do? Just so that this doesn't get lost as a TODO?

Added ticket: https://github.com/ROCm/rocMLIR-internal/issues/2037

Contributor:

It doesn't look like either of these functions is handling input padding or dilations (both of which can be set on MIGraphX ops). This may end up needing the changes in #2007 that add support for these extra attributes.

Contributor Author (@pabloantoniom):

Yes, this would be the place to add the changes from #2007. It should be straightforward, as RocmlirCustomTosaDecompose.cpp is very similar to upstream TosaDecomposeTransposeConv.cpp, which is where the changes would go.

Contributor Author (@pabloantoniom):

> Have you done some running of the bwd_data e2e tests to verify that this works? e.g., https://github.com/ROCm/rocMLIR/blob/develop/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation1-stride2.mlir

Yep, seems to work:

$ LIT_FILTER=mixr-bwd-data-conv-dilation1-stride2 ninja check-rocmlir
[3/228] Running utility command for fusion-e2e-tests
DONE!
[227/228] Running the RocMLIR regression tests

Testing Time: 42.91s

Total Discovered Tests: 1020
  Excluded: 1019 (99.90%)
  Passed  :    1 (0.10%)
