Lowering to linalg for AtenCol2ImOp #4012

Open

gdehame wants to merge 4 commits into llvm:main from gdehame:col2im-linalg-lowering

Conversation

gdehame commented Feb 10, 2025

Added a lowering to linalg for the torch.aten.col2im operation.
Added a unit test to verify the lowering.

gdehame force-pushed the col2im-linalg-lowering branch from 1a2ce71 to 5158529 on January 24, 2026
sahas3 requested review from sahas3 and zjgarvey on January 27, 2026

sahas3 (Member) commented Jan 27, 2026

Hi @gdehame, thanks for the submission.

I haven't reviewed the code yet -- just leaving one piece of high-level feedback. How does this op appear in the IR through the fx.export_and_import API usage? With the source PyTorch program, it would be good to add an e2e test for numerical verification as well.

sahas3 (Member) left a comment

I've taken a first pass and left some comments. Thanks!

Comment on lines 2804 to 2810

sahas3 (Member):

Suggested change:

-if (!(col2imOp.getOutputSize().getDefiningOp() &&
-      isa<Torch::PrimListConstructOp>(
-          col2imOp.getOutputSize().getDefiningOp())))
-  return failure();
-Torch::PrimListConstructOp outputSizes = cast<Torch::PrimListConstructOp>(
-    col2imOp.getOutputSize().getDefiningOp());
+auto outputSizes =
+    col2imOp.getOutputSize().getDefiningOp<Torch::PrimListConstructOp>();
+if (!outputSizes)
+  return failure();
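
(Value::getDefiningOp<OpTy>() folds the null check and the isa/cast into one call, returning null when the value has no defining op or its producer is not an OpTy.)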

gdehame (Author):

Done, thanks!

Comment on lines 2811 to 2828

sahas3 (Member):

The same suggestion applies here: grab the defining op and verify it exists in a single step, as mentioned above.

gdehame (Author):

Done, thanks!

sahas3 (Member):

Similar comment here and elsewhere.

gdehame (Author):

Done, thanks!

Type elementType = outputType.getElementType();
Value outputBuffer = tensor::EmptyOp::create(
rewriter, col2imOp->getLoc(),
ArrayRef<int64_t>{outputType.getDimSize(0), outputType.getDimSize(1),

sahas3 (Member):

I think the assumption here is that outputType dims 0 and 1 are static, but that is not explicitly checked anywhere.
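
A minimal sketch of such a guard, assuming the rewriter and col2imOp names from the snippet above:

// Bail out before building any IR if the leading dims are dynamic;
// isDynamicDim and notifyMatchFailure are standard MLIR APIs.
if (outputType.isDynamicDim(0) || outputType.isDynamicDim(1))
  return rewriter.notifyMatchFailure(
      col2imOp, "expected static batch and channel dims in the output type");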

gdehame (Author):

I added a check, thanks!

Comment on lines 2967 to 2973

sahas3 (Member):

These nested ternary ops are hard to read (here and in other places). Can you break them down?
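
For illustration, the usual fix is to name the inner result; the variable names below are hypothetical:

// Before: one conditional nested inside another.
Value nested = isInBoundsH ? (isInBoundsW ? gathered : zero) : zero;
// After: each condition reads on its own line.
Value innerSel = isInBoundsW ? gathered : zero;
Value result = isInBoundsH ? innerSel : zero;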

gdehame (Author):

Done, thanks!

Comment on lines 2992 to 2997

sahas3 (Member):

This can be checked before creating any new ops in the IR, returning failure() for the unsupported case instead of asserting.
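
A sketch of the pattern, with the predicate name hypothetical:

// Reject the unsupported case up front, before any IR is created, so the
// conversion fails cleanly instead of asserting mid-rewrite.
if (isUnsupportedCase)
  return rewriter.notifyMatchFailure(col2imOp, "unsupported configuration");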

gdehame (Author):

Done, thanks!

sahas3 (Member):

Integer/complex types are not being tested. I'd suggest adding e2e tests covering those combinations.

gdehame (Author):

I started doing so and realized the integer case is buggy: it used arith.constant to generate a zero constant of a possibly signed or unsigned integer type, which is not allowed (the arith dialect only accepts signless integers). I'll work on this further next weekend.
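
One possible fix, sketched under the assumption that elementType may carry torch-style signedness:

// arith.constant rejects signed/unsigned integer types, so strip the
// signedness (keeping the width) before materializing the zero.
Type eTy = elementType;
if (auto intTy = dyn_cast<IntegerType>(eTy); intTy && !intTy.isSignless())
  eTy = IntegerType::get(rewriter.getContext(), intTy.getWidth());
Value zero =
    arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(eTy));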

gdehame (Author) commented Jan 31, 2026:

I updated the PR with additional tests and fixed the lowering.

gdehame (Author) commented Jan 28, 2026

> Hi @gdehame, thanks for the submission.
>
> I haven't reviewed the code yet -- just leaving one piece of high-level feedback. How does this op appear in the IR through the fx.export_and_import API usage? With the source PyTorch program, it would be good to add an e2e test for numerical verification as well.

This came up during an internship I did last year; I've completely forgotten the exact setup and example. I'll try to recreate one over the weekend.

gdehame (Author) commented Feb 1, 2026

> Hi @gdehame, thanks for the submission.
>
> I haven't reviewed the code yet -- just leaving one piece of high-level feedback. How does this op appear in the IR through the fx.export_and_import API usage? With the source PyTorch program, it would be good to add an e2e test for numerical verification as well.

> This came up during an internship I did last year; I've completely forgotten the exact setup and example. I'll try to recreate one over the weekend.

Here is an example:

from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType
import torch
import numpy as np

# Fold(output_size=(14, 30), kernel_size=(2, 2), dilation=(1, 1),
#      padding=(1, 1), stride=(2, 2))
fold = torch.nn.Fold((14, 30), (2, 2), (1, 1), (1, 1), (2, 2))

print(
    fx.export_and_import(
        torch.export.export(fold, (torch.tensor(np.zeros((1, 12, 128))),)),
        output_type=OutputType.LINALG_ON_TENSORS,
    )
)
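
For reference: torch.nn.Fold dispatches to aten::col2im, so exporting this module yields the torch.aten.col2im op this PR lowers. With input shape (1, 12, 128), the 128 columns (8 x 16 block positions) fold back into a 3-channel 14x30 output under the parameters above.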
