
Conversation

@gerion-amd

Currently supports elementwise and reduce operations, in addition to matmul and reshape (within a limited set of cases).
Could be extended to other operations later.
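
For illustration, a minimal before/after sketch of the basic elementwise case (the op, function name, and shapes below are illustrative and not taken from the patch):

func.func @sink_abs(%arg0: tensor<1x8xf32>, %arg1: tensor<1x8xf32>) -> tensor<2x8xf32> {
  // both concat operands are produced by the same elementwise op
  %0 = tosa.abs %arg0 : (tensor<1x8xf32>) -> tensor<1x8xf32>
  %1 = tosa.abs %arg1 : (tensor<1x8xf32>) -> tensor<1x8xf32>
  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (tensor<1x8xf32>, tensor<1x8xf32>) -> tensor<2x8xf32>
  return %2 : tensor<2x8xf32>
}
// expected after sinking: a single concat of the raw inputs, then one abs on the result
//   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x8xf32>, tensor<1x8xf32>) -> tensor<2x8xf32>
//   %1 = tosa.abs %0 : (tensor<2x8xf32>) -> tensor<2x8xf32>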

Contributor @jorickert left a comment

Thanks, especially for adding so many tests.

It looks generally good to me, but I will need to think about the reshape again; the current implementation seems to handle only some cases.

// elementwise
patterns.add<
SinkSpecificOp<tosa::AbsOp>, SinkSpecificOp<tosa::BitwiseNotOp>,
SinkSpecificOp<tosa::CeilOp>, SinkSpecificOp<tosa::ClampOp>,
Contributor

I would not sink Clamp or Floor: in the form of a ReLU they are often fused with the previous op, and the sinking may prevent that fusion.

Contributor

The same is also true for the LeakyRelu decomposition, but this is harder to prevent, as it consists of multiple ops.

Author

I think the sinking can bring a much greater benefit when the previous op can also be sunk through the concat, but with the current design that only triggers once the activation has been sunk first.

Contributor

Yeah, but we cannot always sink the other ops.
I am fine with how it is right now; if it causes problems downstream we can still adjust it.

Contributor

@ehsan-toosi What do you think about this?

Contributor

@jorickert This is actually a very solid point. This is going to break all fusion in situations with conv, relu/lrelu, and concat. We should find a unit test with conv->relu->concat and show whether it is better to move the relu after the concat, but I suspect we are going to underperform. I agree with @jorickert that we have to disable it for ReLU and LeakyReLU until we also support sinking the conv down. After that we can enable these two ops again.

Contributor

We can always control this by first doing the matching of ops+activation and only doing the sinking later.

Author @gerion-amd

It looks generally good to me, but I will need to think about the reshape again; the current implementation seems to handle only some cases.

This is correct. Only reshapes that extend or shrink the shape by inserting or removing size-1 dimensions are transformed. This seemed to be enough for real-world tests, so I left it that way, with the option to handle more advanced cases later. Maybe I can add a comment explaining this.
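
For example, a sketch of the kind of case this covers (shapes illustrative, assuming the pattern maps the concat axis through the inserted unit dimensions):

func.func @sink_unit_dim_reshape(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) -> tensor<1x1x16x8xf32> {
  // the reshapes only insert size-1 dimensions
  %0 = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 8, 8>} : (tensor<8x8xf32>) -> tensor<1x1x8x8xf32>
  %1 = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 8, 8>} : (tensor<8x8xf32>) -> tensor<1x1x8x8xf32>
  %2 = tosa.concat %0, %1 {axis = 2 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x1x16x8xf32>
  return %2 : tensor<1x1x16x8xf32>
}
// expected after sinking: concat on the corresponding pre-reshape axis, then a single reshape
//   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<16x8xf32>
//   %1 = tosa.reshape %0 {new_shape = array<i64: 1, 1, 16, 8>} : (tensor<16x8xf32>) -> tensor<1x1x16x8xf32>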

@gerion-amd force-pushed the gerion-amd.sink-ops-through-concat branch 2 times, most recently from b3b89b2 to 4a5a491, on October 2, 2025 12:55
Author @gerion-amd

All comments should be addressed now.

@gerion-amd force-pushed the gerion-amd.sink-ops-through-concat branch 5 times, most recently from 6c863cd to b6522a8, on October 6, 2025 09:12
Contributor @ehsan-toosi left a comment

I did a high-level review and checked the transformation. Just one major comment, which is about using a trait, as you and Jonas already discussed.
I'll do the detailed code review if I find some time today.

Comment on lines 278 to 419
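// Two branches each compute reshape/transpose -> reshape -> matmul -> reshape and are
// then concatenated on axis 1. The reshape directly before each matmul, the matmuls,
// and the reshapes after them are expected to be sunk below the concat: the matmul
// inputs get concatenated on the batch dimension and a single batched matmul remains,
// while the leading reshapes and transposes stay in place.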
func.func @reshape_complex_match(%arg0: tensor<1x42x1x1xbf16>, %arg1: tensor<1x42x1x1xbf16>, %arg2: tensor<12x42x1x1xbf16>, %arg3: tensor<12x42x1x1xbf16>) -> tensor<1x2x1x12xbf16> {
%0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
%1 = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}> : () -> tensor<6xi32>
%2 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%3 = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16>
%4 = tosa.transpose %3, %1 : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16>
%5 = tosa.reshape %4 {new_shape = array<i64: 1, 1, 42>} : (tensor<1x1x1x1x1x42xbf16>) -> tensor<1x1x42xbf16>
%6 = tosa.reshape %arg2 {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16>
%7 = tosa.transpose %6, %0 : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16>
%8 = tosa.matmul %5, %7 : (tensor<1x1x42xbf16>, tensor<1x42x12xbf16>) -> tensor<1x1x12xbf16>
%9 = tosa.reshape %8 {new_shape = array<i64: 1, 1, 1, 12>} : (tensor<1x1x12xbf16>) -> tensor<1x1x1x12xbf16>
%10 = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16>
%11 = tosa.transpose %10, %1 : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16>
%12 = tosa.reshape %11 {new_shape = array<i64: 1, 1, 42>} : (tensor<1x1x1x1x1x42xbf16>) -> tensor<1x1x42xbf16>
%13 = tosa.reshape %arg3 {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16>
%14 = tosa.transpose %13, %0 : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16>
%15 = tosa.matmul %12, %14 : (tensor<1x1x42xbf16>, tensor<1x42x12xbf16>) -> tensor<1x1x12xbf16>
%16 = tosa.reshape %15 {new_shape = array<i64: 1, 1, 1, 12>} : (tensor<1x1x12xbf16>) -> tensor<1x1x1x12xbf16>
%17 = tosa.concat %9, %16 {axis = 1 : i32} : (tensor<1x1x1x12xbf16>, tensor<1x1x1x12xbf16>) -> tensor<1x2x1x12xbf16>
return %17 : tensor<1x2x1x12xbf16>
}

// CHECK-LABEL: func.func @reshape_complex_match
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x42x1x1xbf16>, [[PARAM_1_:%.+]]: tensor<1x42x1x1xbf16>, [[PARAM_2_:%.+]]: tensor<12x42x1x1xbf16>, [[PARAM_3_:%.+]]: tensor<12x42x1x1xbf16>) -> tensor<1x2x1x12xbf16> {
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}> : () -> tensor<6xi32>
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_1_]] : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16>
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.transpose [[VAR_4_]], [[VAR_0_]] : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16>
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.transpose [[VAR_6_]], [[VAR_1_]] : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16>
// CHECK-DAG: [[VAR_8_:%.+]] = tosa.reshape [[PARAM_3_]] {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_9_:%.+]] = tosa.transpose [[VAR_8_]], [[VAR_0_]] : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16>
// CHECK-DAG: [[VAR_10_:%.+]] = tosa.concat [[VAR_3_]], [[VAR_7_]] {axis = 3 : i32} : (tensor<1x1x1x1x1x42xbf16>, tensor<1x1x1x1x1x42xbf16>) -> tensor<1x1x1x2x1x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_11_:%.+]] = tosa.reshape [[VAR_10_]] {new_shape = array<i64: 2, 1, 42>} : (tensor<1x1x1x2x1x42xbf16>) -> tensor<2x1x42xbf16>
// CHECK-DAG: [[VAR_12_:%.+]] = tosa.concat [[VAR_5_]], [[VAR_9_]] {axis = 0 : i32} : (tensor<1x42x12xbf16>, tensor<1x42x12xbf16>) -> tensor<2x42x12xbf16>
// CHECK: [[VAR_13_:%.+]] = tosa.matmul [[VAR_11_]], [[VAR_12_]] : (tensor<2x1x42xbf16>, tensor<2x42x12xbf16>) -> tensor<2x1x12xbf16>
// CHECK: [[VAR_14_:%.+]] = tosa.reshape [[VAR_13_]] {new_shape = array<i64: 1, 2, 1, 12>} : (tensor<2x1x12xbf16>) -> tensor<1x2x1x12xbf16>
// CHECK: return [[VAR_14_]] : tensor<1x2x1x12xbf16>
// CHECK: }
// -----
Contributor

What is this test doing? Could you write an explanation/comment, please? It's hard to follow the IR.

Author

I added a comment now.

Comment on lines +680 to +788
%0 = tosa.select %arg0, %arg1, %arg2 : (!select_type, !in_type, !in_type) -> !in_type
%1 = tosa.select %arg0, %arg2, %arg1 : (!select_type, !in_type, !in_type) -> !in_type
%2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
return %2 : !out_type
}
// CHECK-LABEL: func.func @switch_op_concat
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>, [[PARAM_2_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<2x8x8xi1>
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.concat [[PARAM_2_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_0_]], [[VAR_1_]], [[VAR_2_]] : (tensor<2x8x8xi1>, tensor<2x8x8xf32>, tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
// CHECK: return [[VAR_3_]] : tensor<2x8x8xf32>
// CHECK: }
// -----
Contributor

Can you add a test with select where the first operand differs?

Author

You mean not using %arg0 in both cases? Since the transformation currently concatenates %arg0 twice, it doesn't "know" or check whether the operands differ. The outcome would be the same. Am I missing a transformation that should not be valid?
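
A hypothetical variant of the test above with distinct predicates (names made up; the expected result assumes the pattern treats the predicate like any other operand):

func.func @switch_op_concat_distinct_pred(%cond0: tensor<1x8x8xi1>, %cond1: tensor<1x8x8xi1>, %a: tensor<1x8x8xf32>, %b: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
  // each select has its own predicate
  %0 = tosa.select %cond0, %a, %b : (tensor<1x8x8xi1>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  %1 = tosa.select %cond1, %b, %a : (tensor<1x8x8xi1>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
  return %2 : tensor<2x8x8xf32>
}
// expected after sinking: tosa.concat %cond0, %cond1, tosa.concat %a, %b, and
// tosa.concat %b, %a (all on axis 0), followed by a single tosa.select on the
// concatenated tensors.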

Comment on lines 365 to 411
patterns.add<
SinkSpecificOp<tosa::AbsOp>, SinkSpecificOp<tosa::BitwiseNotOp>,
SinkSpecificOp<tosa::CeilOp>, SinkSpecificOp<tosa::ClampOp>,
SinkSpecificOp<tosa::ClzOp>, SinkSpecificOp<tosa::CosOp>,
SinkSpecificOp<tosa::EqualOp>, SinkSpecificOp<tosa::ErfOp>,
SinkSpecificOp<tosa::ExpOp>, SinkSpecificOp<tosa::FloorOp>,
SinkSpecificOp<tosa::GreaterOp>, SinkSpecificOp<tosa::GreaterEqualOp>,
SinkSpecificOp<tosa::LogOp>, SinkSpecificOp<tosa::LogicalAndOp>,
SinkSpecificOp<tosa::LogicalNotOp>, SinkSpecificOp<tosa::LogicalOrOp>,
SinkSpecificOp<tosa::NegateOp>, SinkSpecificOp<tosa::ReciprocalOp>,
SinkSpecificOp<tosa::RsqrtOp>, SinkSpecificOp<tosa::SelectOp>,
SinkSpecificOp<tosa::SigmoidOp>, SinkSpecificOp<tosa::SinOp>,
SinkSpecificOp<tosa::TanhOp>>(ctx, /*benefit=*/2);
patterns.add<SinkElementwiseBroadcastableOp<tosa::AddOp>,
SinkElementwiseBroadcastableOp<tosa::ArithmeticRightShiftOp>,
SinkElementwiseBroadcastableOp<tosa::BitwiseAndOp>,
SinkElementwiseBroadcastableOp<tosa::BitwiseOrOp>,
SinkElementwiseBroadcastableOp<tosa::BitwiseXorOp>,
SinkElementwiseBroadcastableOp<tosa::IntDivOp>,
SinkElementwiseBroadcastableOp<tosa::LogicalLeftShiftOp>,
SinkElementwiseBroadcastableOp<tosa::LogicalRightShiftOp>,
SinkElementwiseBroadcastableOp<tosa::LogicalXorOp>,
SinkElementwiseBroadcastableOp<tosa::MaximumOp>,
SinkElementwiseBroadcastableOp<tosa::MinimumOp>,
SinkElementwiseBroadcastableOp<tosa::PowOp>,
SinkElementwiseBroadcastableOp<tosa::SubOp>>(ctx, 2);
// reduce
patterns
.add<SinkReduceOp<tosa::ReduceAllOp>, SinkReduceOp<tosa::ReduceAnyOp>,
SinkReduceOp<tosa::ReduceMaxOp>, SinkReduceOp<tosa::ReduceMinOp>,
SinkReduceOp<tosa::ReduceProdOp>, SinkReduceOp<tosa::ReduceSumOp>>(
ctx, 2);
// others
patterns.add<SinkMatmulOp>(ctx, 2);
patterns.add<SinkReshapeOp>(ctx, 2);
patterns.add<SinkGenericOp>(ctx, 1, stats, os);
Contributor

Can you move these to a populateSinkInputOpsThroughConcatPatterns function and expose it?

Author

Do you mean a private class method? I have done it now.

Comment on lines 373 to 400
SinkSpecificOp<tosa::LogicalNotOp>, SinkSpecificOp<tosa::LogicalOrOp>,
SinkSpecificOp<tosa::NegateOp>, SinkSpecificOp<tosa::ReciprocalOp>,
SinkSpecificOp<tosa::RsqrtOp>, SinkSpecificOp<tosa::SelectOp>,
SinkSpecificOp<tosa::SigmoidOp>, SinkSpecificOp<tosa::SinOp>,
SinkSpecificOp<tosa::TanhOp>>(ctx, /*benefit=*/2);
patterns.add<SinkElementwiseBroadcastableOp<tosa::AddOp>,
SinkElementwiseBroadcastableOp<tosa::ArithmeticRightShiftOp>,
SinkElementwiseBroadcastableOp<tosa::BitwiseAndOp>,
SinkElementwiseBroadcastableOp<tosa::BitwiseOrOp>,
SinkElementwiseBroadcastableOp<tosa::BitwiseXorOp>,
SinkElementwiseBroadcastableOp<tosa::IntDivOp>,
SinkElementwiseBroadcastableOp<tosa::LogicalLeftShiftOp>,
SinkElementwiseBroadcastableOp<tosa::LogicalRightShiftOp>,
SinkElementwiseBroadcastableOp<tosa::LogicalXorOp>,
SinkElementwiseBroadcastableOp<tosa::MaximumOp>,
SinkElementwiseBroadcastableOp<tosa::MinimumOp>,
SinkElementwiseBroadcastableOp<tosa::PowOp>,
Contributor

It would be really good if you define traits for the TOSA ops and then rewrite them using OpTraitRewritePattern.

Author

Is this possible? I'm matching only tosa::ConcatOp. OpTraitRewritePattern seems to work for transformations that want to match a whole trait.

@gerion-amd force-pushed the gerion-amd.sink-ops-through-concat branch 4 times, most recently from 3fc814b to 668aed2, on October 8, 2025 09:44
@gerion-amd requested a review from ehsan-toosi on October 8, 2025 09:44
@franciscofd removed the request for review from ehsan-toosi on October 8, 2025 11:22
Contributor @roberteg16 left a comment

Looks very nice! I have yet to finish reviewing the reshape part; I'll come back ASAP.

Contributor @roberteg16 left a comment

LGTM, I would just like to have a few more reshape tests.

@gerion-amd force-pushed the gerion-amd.sink-ops-through-concat branch 3 times, most recently from c11c500 to 33a4759, on October 10, 2025 19:01
@gerion-amd requested a review from roberteg16 on October 10, 2025 19:04
@gerion-amd force-pushed the gerion-amd.sink-ops-through-concat branch from 33a4759 to 2fda478 on October 13, 2025 08:16
@gerion-amd force-pushed the gerion-amd.sink-ops-through-concat branch from 2fda478 to 0803705 on October 20, 2025 10:26
@gerion-amd force-pushed the gerion-amd.sink-ops-through-concat branch from 0803705 to ba93cf8 on October 20, 2025 10:27
Author @gerion-amd

Changes from 0803705 to ba93cf8:

  • Rename the variables with old/new prefixes to beforeReshape/afterReshape to make it clearer that their position in the operation graph is meant
  • Add documentation for the reshape method
  • Fix a bug for reshapes where the concat-axis dimension is not of size 1. The approach is to also include the concat-axis dimension size in the product and check that the dimension is the same before and after.
