feat(tosa): sink tosa ops through tosa.concat #671
base: aie-public
Conversation
Thanks, especially for adding so many tests.
It looks generally good to me, but I will need to think about the reshape handling again; the current implementation seems to handle only some cases.
// elementwise
patterns.add<
    SinkSpecificOp<tosa::AbsOp>, SinkSpecificOp<tosa::BitwiseNotOp>,
    SinkSpecificOp<tosa::CeilOp>, SinkSpecificOp<tosa::ClampOp>,
I would not sink Clamp or Floor: in the form of a ReLU they are often fused with the previous op, and the sinking may prevent that fusion.
The same is also true for the LeakyRelu decomposition, but that one is harder to prevent, as it consists of multiple ops.
I think the sinking can bring a much greater benefit when the previous op can also be sunk through the concat, but with the current design that only triggers once the activation has been sunk first.
Yeah, but we cannot always sink the other ops.
I am fine with how it is right now; if it causes problems downstream, we can still adjust it.
@ehsan-toosi What do you think about this?
@jorickert This is actually a very solid point. This is going to break all fusion in situations where we have conv, relu/lrelu, and concat. We should find a unit test that has conv->relu->concat and show whether it is better to move the relu after the concat, but I suspect we are going to underperform. I agree with @jorickert that we have to disable it for ReLU and LeakyReLU until we support sinking the conv as well. After that we can enable these two ops again.
We can always control this by first doing the matching of ops+activation and then doing the sinking afterwards.
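For illustration, a minimal sketch of the case being discussed (function name, shapes, and clamp bounds are made up, and the clamp attribute names depend on the TOSA dialect version); assume %arg0 and %arg1 are produced by two tosa.conv2d ops that the backend would like to fuse with the ReLUs:

```mlir
// Hypothetical example: two ReLUs (tosa.clamp) feeding a concat.
func.func @relu_concat(%arg0: tensor<1x8x8xf32>, %arg1: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
  %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 3.40282347E+38 : f32, min_int = 0 : i64, max_int = 2147483647 : i64} : (tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  %1 = tosa.clamp %arg1 {min_fp = 0.0 : f32, max_fp = 3.40282347E+38 : f32, min_int = 0 : i64, max_int = 2147483647 : i64} : (tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  // Sinking turns this into concat(%arg0, %arg1) followed by a single clamp,
  // moving the ReLUs away from the convs that produced %arg0 and %arg1.
  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
  return %2 : tensor<2x8x8xf32>
}
```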
This is correct. Only reshapes that extend or shrink the shape by unit (size-1) dimensions are transformed. This seemed to be enough for the real-world tests, so I left it that way, with the option of handling more advanced cases later. Maybe I can add a comment explaining this.
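For concreteness, a small sketch (with made-up shapes and function name) of the distinction: the first reshape only inserts size-1 dimensions and would be transformed, while the second redistributes non-unit dimensions and would be left in place.

```mlir
func.func @reshape_kinds(%a: tensor<1x12x42xf32>, %b: tensor<1x12x42xf32>)
    -> (tensor<1x1x12x42x1xf32>, tensor<1x6x84xf32>) {
  // Handled: only unit dimensions are inserted (1x12x42 -> 1x1x12x42x1).
  %0 = tosa.reshape %a {new_shape = array<i64: 1, 1, 12, 42, 1>} : (tensor<1x12x42xf32>) -> tensor<1x1x12x42x1xf32>
  // Not handled: non-unit dimensions are redistributed (1x12x42 -> 1x6x84).
  %1 = tosa.reshape %b {new_shape = array<i64: 1, 6, 84>} : (tensor<1x12x42xf32>) -> tensor<1x6x84xf32>
  return %0, %1 : tensor<1x1x12x42x1xf32>, tensor<1x6x84xf32>
}
```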
Force-pushed from b3b89b2 to 4a5a491.
All comments should be addressed now.
Force-pushed from 6c863cd to b6522a8.
I did a high-level review and checked the transformation. Just one major comment, which is about using a trait, as you and Jonas already discussed.
I'll do the detailed code review if I find some time today.
func.func @reshape_complex_match(%arg0: tensor<1x42x1x1xbf16>, %arg1: tensor<1x42x1x1xbf16>, %arg2: tensor<12x42x1x1xbf16>, %arg3: tensor<12x42x1x1xbf16>) -> tensor<1x2x1x12xbf16> {
  %0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
  %1 = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}> : () -> tensor<6xi32>
  %2 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %3 = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16>
  %4 = tosa.transpose %3, %1 : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16>
  %5 = tosa.reshape %4 {new_shape = array<i64: 1, 1, 42>} : (tensor<1x1x1x1x1x42xbf16>) -> tensor<1x1x42xbf16>
  %6 = tosa.reshape %arg2 {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16>
  %7 = tosa.transpose %6, %0 : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16>
  %8 = tosa.matmul %5, %7 : (tensor<1x1x42xbf16>, tensor<1x42x12xbf16>) -> tensor<1x1x12xbf16>
  %9 = tosa.reshape %8 {new_shape = array<i64: 1, 1, 1, 12>} : (tensor<1x1x12xbf16>) -> tensor<1x1x1x12xbf16>
  %10 = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16>
  %11 = tosa.transpose %10, %1 : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16>
  %12 = tosa.reshape %11 {new_shape = array<i64: 1, 1, 42>} : (tensor<1x1x1x1x1x42xbf16>) -> tensor<1x1x42xbf16>
  %13 = tosa.reshape %arg3 {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16>
  %14 = tosa.transpose %13, %0 : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16>
  %15 = tosa.matmul %12, %14 : (tensor<1x1x42xbf16>, tensor<1x42x12xbf16>) -> tensor<1x1x12xbf16>
  %16 = tosa.reshape %15 {new_shape = array<i64: 1, 1, 1, 12>} : (tensor<1x1x12xbf16>) -> tensor<1x1x1x12xbf16>
  %17 = tosa.concat %9, %16 {axis = 1 : i32} : (tensor<1x1x1x12xbf16>, tensor<1x1x1x12xbf16>) -> tensor<1x2x1x12xbf16>
  return %17 : tensor<1x2x1x12xbf16>
}

// CHECK-LABEL: func.func @reshape_complex_match
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x42x1x1xbf16>, [[PARAM_1_:%.+]]: tensor<1x42x1x1xbf16>, [[PARAM_2_:%.+]]: tensor<12x42x1x1xbf16>, [[PARAM_3_:%.+]]: tensor<12x42x1x1xbf16>) -> tensor<1x2x1x12xbf16> {
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}> : () -> tensor<6xi32>
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_1_]] : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16>
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.transpose [[VAR_4_]], [[VAR_0_]] : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16>
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1, 1, 1, 42>} : (tensor<1x42x1x1xbf16>) -> tensor<1x1x1x1x1x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.transpose [[VAR_6_]], [[VAR_1_]] : (tensor<1x1x1x1x1x42xbf16>, tensor<6xi32>) -> tensor<1x1x1x1x1x42xbf16>
// CHECK-DAG: [[VAR_8_:%.+]] = tosa.reshape [[PARAM_3_]] {new_shape = array<i64: 1, 12, 42>} : (tensor<12x42x1x1xbf16>) -> tensor<1x12x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_9_:%.+]] = tosa.transpose [[VAR_8_]], [[VAR_0_]] : (tensor<1x12x42xbf16>, tensor<3xi32>) -> tensor<1x42x12xbf16>
// CHECK-DAG: [[VAR_10_:%.+]] = tosa.concat [[VAR_3_]], [[VAR_7_]] {axis = 3 : i32} : (tensor<1x1x1x1x1x42xbf16>, tensor<1x1x1x1x1x42xbf16>) -> tensor<1x1x1x2x1x42xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_11_:%.+]] = tosa.reshape [[VAR_10_]] {new_shape = array<i64: 2, 1, 42>} : (tensor<1x1x1x2x1x42xbf16>) -> tensor<2x1x42xbf16>
// CHECK-DAG: [[VAR_12_:%.+]] = tosa.concat [[VAR_5_]], [[VAR_9_]] {axis = 0 : i32} : (tensor<1x42x12xbf16>, tensor<1x42x12xbf16>) -> tensor<2x42x12xbf16>
// CHECK: [[VAR_13_:%.+]] = tosa.matmul [[VAR_11_]], [[VAR_12_]] : (tensor<2x1x42xbf16>, tensor<2x42x12xbf16>) -> tensor<2x1x12xbf16>
// CHECK: [[VAR_14_:%.+]] = tosa.reshape [[VAR_13_]] {new_shape = array<i64: 1, 2, 1, 12>} : (tensor<2x1x12xbf16>) -> tensor<1x2x1x12xbf16>
// CHECK: return [[VAR_14_]] : tensor<1x2x1x12xbf16>
// CHECK: }
// -----
What is this test doing? Could you write an explanation/comment, please? It's hard to follow the IR.
I added a comment now.
  %0 = tosa.select %arg0, %arg1, %arg2 : (!select_type, !in_type, !in_type) -> !in_type
  %1 = tosa.select %arg0, %arg2, %arg1 : (!select_type, !in_type, !in_type) -> !in_type
  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (!in_type, !in_type) -> !out_type
  return %2 : !out_type
}
// CHECK-LABEL: func.func @switch_op_concat
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x8xi1>, [[PARAM_1_:%.+]]: tensor<1x8x8xf32>, [[PARAM_2_:%.+]]: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]] {axis = 0 : i32} : (tensor<1x8x8xi1>, tensor<1x8x8xi1>) -> tensor<2x8x8xi1>
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.concat [[PARAM_2_]], [[PARAM_1_]] {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_0_]], [[VAR_1_]], [[VAR_2_]] : (tensor<2x8x8xi1>, tensor<2x8x8xf32>, tensor<2x8x8xf32>) -> tensor<2x8x8xf32>
// CHECK: return [[VAR_3_]] : tensor<2x8x8xf32>
// CHECK: }
// -----
Can you add a test with selects that use different first operands?
You mean not using %arg0 in both cases? Since the transformation currently concatenates %arg0 twice, it doesn't "know" or check whether the operands are different; the outcome would be the same. Am I missing a transformation that should not be valid?
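For reference, a sketch of the variant asked about (hypothetical function name, mirroring the shapes of the test above): with two distinct conditions the expected rewrite simply concatenates the conditions as well, then applies a single select, so the resulting structure matches the single-condition test.

```mlir
// Hypothetical test with two different select conditions. The expected rewrite
// concatenates %c0/%c1, %a/%b, and %b/%a, then applies one tosa.select.
func.func @switch_op_concat_two_conds(%c0: tensor<1x8x8xi1>, %c1: tensor<1x8x8xi1>,
    %a: tensor<1x8x8xf32>, %b: tensor<1x8x8xf32>) -> tensor<2x8x8xf32> {
  %0 = tosa.select %c0, %a, %b : (tensor<1x8x8xi1>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  %1 = tosa.select %c1, %b, %a : (tensor<1x8x8xi1>, tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> tensor<2x8x8xf32>
  return %2 : tensor<2x8x8xf32>
}
```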
patterns.add<
    SinkSpecificOp<tosa::AbsOp>, SinkSpecificOp<tosa::BitwiseNotOp>,
    SinkSpecificOp<tosa::CeilOp>, SinkSpecificOp<tosa::ClampOp>,
    SinkSpecificOp<tosa::ClzOp>, SinkSpecificOp<tosa::CosOp>,
    SinkSpecificOp<tosa::EqualOp>, SinkSpecificOp<tosa::ErfOp>,
    SinkSpecificOp<tosa::ExpOp>, SinkSpecificOp<tosa::FloorOp>,
    SinkSpecificOp<tosa::GreaterOp>, SinkSpecificOp<tosa::GreaterEqualOp>,
    SinkSpecificOp<tosa::LogOp>, SinkSpecificOp<tosa::LogicalAndOp>,
    SinkSpecificOp<tosa::LogicalNotOp>, SinkSpecificOp<tosa::LogicalOrOp>,
    SinkSpecificOp<tosa::NegateOp>, SinkSpecificOp<tosa::ReciprocalOp>,
    SinkSpecificOp<tosa::RsqrtOp>, SinkSpecificOp<tosa::SelectOp>,
    SinkSpecificOp<tosa::SigmoidOp>, SinkSpecificOp<tosa::SinOp>,
    SinkSpecificOp<tosa::TanhOp>>(ctx, /*benefit=*/2);
patterns.add<SinkElementwiseBroadcastableOp<tosa::AddOp>,
             SinkElementwiseBroadcastableOp<tosa::ArithmeticRightShiftOp>,
             SinkElementwiseBroadcastableOp<tosa::BitwiseAndOp>,
             SinkElementwiseBroadcastableOp<tosa::BitwiseOrOp>,
             SinkElementwiseBroadcastableOp<tosa::BitwiseXorOp>,
             SinkElementwiseBroadcastableOp<tosa::IntDivOp>,
             SinkElementwiseBroadcastableOp<tosa::LogicalLeftShiftOp>,
             SinkElementwiseBroadcastableOp<tosa::LogicalRightShiftOp>,
             SinkElementwiseBroadcastableOp<tosa::LogicalXorOp>,
             SinkElementwiseBroadcastableOp<tosa::MaximumOp>,
             SinkElementwiseBroadcastableOp<tosa::MinimumOp>,
             SinkElementwiseBroadcastableOp<tosa::PowOp>,
             SinkElementwiseBroadcastableOp<tosa::SubOp>>(ctx, 2);
// reduce
patterns
    .add<SinkReduceOp<tosa::ReduceAllOp>, SinkReduceOp<tosa::ReduceAnyOp>,
         SinkReduceOp<tosa::ReduceMaxOp>, SinkReduceOp<tosa::ReduceMinOp>,
         SinkReduceOp<tosa::ReduceProdOp>, SinkReduceOp<tosa::ReduceSumOp>>(
        ctx, 2);
// others
patterns.add<SinkMatmulOp>(ctx, 2);
patterns.add<SinkReshapeOp>(ctx, 2);
patterns.add<SinkGenericOp>(ctx, 1, stats, os);
Can you move these into a populateSinkInputOpsThroughConcatPatterns function and expose it?
Do you mean a private class method? I have done it now.
    SinkSpecificOp<tosa::LogicalNotOp>, SinkSpecificOp<tosa::LogicalOrOp>,
    SinkSpecificOp<tosa::NegateOp>, SinkSpecificOp<tosa::ReciprocalOp>,
    SinkSpecificOp<tosa::RsqrtOp>, SinkSpecificOp<tosa::SelectOp>,
    SinkSpecificOp<tosa::SigmoidOp>, SinkSpecificOp<tosa::SinOp>,
    SinkSpecificOp<tosa::TanhOp>>(ctx, /*benefit=*/2);
patterns.add<SinkElementwiseBroadcastableOp<tosa::AddOp>,
             SinkElementwiseBroadcastableOp<tosa::ArithmeticRightShiftOp>,
             SinkElementwiseBroadcastableOp<tosa::BitwiseAndOp>,
             SinkElementwiseBroadcastableOp<tosa::BitwiseOrOp>,
             SinkElementwiseBroadcastableOp<tosa::BitwiseXorOp>,
             SinkElementwiseBroadcastableOp<tosa::IntDivOp>,
             SinkElementwiseBroadcastableOp<tosa::LogicalLeftShiftOp>,
             SinkElementwiseBroadcastableOp<tosa::LogicalRightShiftOp>,
             SinkElementwiseBroadcastableOp<tosa::LogicalXorOp>,
             SinkElementwiseBroadcastableOp<tosa::MaximumOp>,
             SinkElementwiseBroadcastableOp<tosa::MinimumOp>,
             SinkElementwiseBroadcastableOp<tosa::PowOp>,
It would be really good if you defined traits for the TOSA ops and then rewrote them using OpTraitRewritePattern.
Is this possible? I'm matching only tosa::ConcatOp. OpTraitRewritePattern seems to work for transformations that want to match a whole trait.
Force-pushed from 3fc814b to 668aed2.
Looks very nice! I have yet to finish the review of the reshape; I'll come back ASAP.
LGTM, I would just like to have a few more reshape tests.
Force-pushed from c11c500 to 33a4759.
Force-pushed from 33a4759 to 2fda478.
Force-pushed from 2fda478 to 0803705.
Force-pushed from 0803705 to ba93cf8.
Changes from 0803705 to ba93cf8:
Currently supports elementwise and reduce operations, in addition to matmul and reshape (within a limited set).
Could be extended to other operations later.
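A minimal before/after sketch of the core rewrite (shapes and function names are made up, using tosa.abs as the sunk elementwise op):

```mlir
// Before: each concat operand is produced by its own tosa.abs.
func.func @before(%arg0: tensor<1x8xf32>, %arg1: tensor<1x8xf32>) -> tensor<2x8xf32> {
  %0 = tosa.abs %arg0 : (tensor<1x8xf32>) -> tensor<1x8xf32>
  %1 = tosa.abs %arg1 : (tensor<1x8xf32>) -> tensor<1x8xf32>
  %2 = tosa.concat %0, %1 {axis = 0 : i32} : (tensor<1x8xf32>, tensor<1x8xf32>) -> tensor<2x8xf32>
  return %2 : tensor<2x8xf32>
}

// After: the abs is sunk below the concat and applied once to the concatenated result.
func.func @after(%arg0: tensor<1x8xf32>, %arg1: tensor<1x8xf32>) -> tensor<2x8xf32> {
  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x8xf32>, tensor<1x8xf32>) -> tensor<2x8xf32>
  %1 = tosa.abs %0 : (tensor<2x8xf32>) -> tensor<2x8xf32>
  return %1 : tensor<2x8xf32>
}
```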