-
Notifications
You must be signed in to change notification settings - Fork 780
Harmonize *ScaledMMAAttr operand order and drop MMAFragment
#22465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Very harmonious, this PR restores tranquility. |
079ee69 to
305f0c2
Compare
Signed-off-by: Benoit Jacob <[email protected]>
305f0c2 to
1a009cf
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not participate in the discussion about what the order should be, as there are reordering changes in lit tests. The change itself looks good to me.
| using MMAIntrinsicTy = decltype(mma.getIntrinsic()); | ||
| const bool isScaled = std::is_same<MMAIntrinsicTy, ScaledMMAIntrinsic>::value; | ||
| const bool isLhs = isIntrinsicLhs<MMAIntrinsicTy>(operandIdx); | ||
| const bool isRhs = isIntrinsicRhs<MMAIntrinsicTy>(operandIdx); | ||
| const bool isAcc = isIntrinsicAcc<MMAIntrinsicTy>(operandIdx); | ||
| const bool isLhsScale = isIntrinsicLhsScale<MMAIntrinsicTy>(operandIdx); | ||
| const bool isRhsScale = isIntrinsicRhsScale<MMAIntrinsicTy>(operandIdx); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks much better than magic numbers to me.
| ValueRange inputs = linalgOp->getOperands(); | ||
|
|
||
| SmallVector<Type> eltTypes; | ||
| smmaKind.getElementTypes(eltTypes); | ||
| if (cast<RankedTensorType>(inputs[0].getType()).getElementType() != | ||
| eltTypes[0] || | ||
| cast<RankedTensorType>(inputs[2].getType()).getElementType() != | ||
| eltTypes[2] || | ||
| cast<RankedTensorType>(inputs[4].getType()).getElementType() != | ||
| eltTypes[4]) { | ||
| return failure(); | ||
| for (int i : | ||
| {kScaledMMAOperandLhs, kScaledMMAOperandRhs, kScaledMMAOperandAcc}) { | ||
| if (cast<RankedTensorType>(inputs[i].getType()).getElementType() != | ||
| eltTypes[i]) { | ||
| return failure(); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks much better to me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This PR does multiple things that were easier done all at once:
*ScaledMMAAttroperand order:ScaledMMAAttrwaslhs, lhs_scale, rhs, rhs_scale, while the operand order ofDataTiledScaledMMAAttrwaslhs, rhs, lhs_scale, rhs_scale.ScaledMMAAttrto match theDataTiledScaledMMAAttrconvention. This propagates to a change of operand order in the enclosinginner_tiledops.MMAFragment:MMAFragment, that had unclear semantics: the enum values were sometimes used as opaque symbolic enums to refer to operand by "role", e.g. "Lhs", and sometimes used as the underlying integer values as operand indices, e.g. "Rhs == 1". This was originally OK as all MMA-like ops had the same 3 operands Lhs, Rhs, Acc. But when ScaledMMAAttr was introduced, that didn't... scale: now the preexisting enum value "Rhs == 1" didn't equal anymore the corresponding operand index under thelhs, lhs_scale, rhs, rhs_scaleconvention (1 != 2) and even regardless of convention, the enum value "Acc ==2" never corresponded to operand index anymore (2 != 4).operandIndex.