Commit b0103f6

Harmonize *ScaledMMAAttr operand order and drop MMAFragment (#22465)
This PR does multiple things that were easier done all at once:

1. Harmonize `*ScaledMMAAttr` operand order:
   * The operand order of `ScaledMMAAttr` was `lhs, lhs_scale, rhs, rhs_scale`, while the operand order of `DataTiledScaledMMAAttr` was `lhs, rhs, lhs_scale, rhs_scale`.
   * This PR changes `ScaledMMAAttr` to match the `DataTiledScaledMMAAttr` convention. This propagates to a change of operand order in the enclosing `inner_tiled` ops.
2. Drop `MMAFragment`:
   * There used to be a TableGen enum `MMAFragment` with unclear semantics: its values were sometimes used as opaque symbolic enums to refer to operands by "role", e.g. "Lhs", and sometimes used through their underlying integer values as operand indices, e.g. "Rhs == 1". This was originally fine, as all MMA-like ops had the same 3 operands Lhs, Rhs, Acc. But when `ScaledMMAAttr` was introduced, that didn't... scale: the preexisting enum value "Rhs == 1" no longer equaled the corresponding operand index under the `lhs, lhs_scale, rhs, rhs_scale` convention (1 != 2), and regardless of convention, the enum value "Acc == 2" no longer corresponded to the operand index at all (2 != 4).
   * This PR drops `MMAFragment` and generalizes the use of a plain integer `operandIndex`.
   * This is made reasonable to implement by the harmonization of operand orders (item 1 above). Without item 1, we would feel more of a need for opaque "role" enums to replace `MMAFragment`, to avoid fixing up operand indices between the two conventions.
3. Drop some logic that only existed to make up for the discrepancy between conventions and the fuzzy semantics of `MMAFragment`.

Signed-off-by: Benoit Jacob <[email protected]>
1 parent bd3af49 commit b0103f6

23 files changed: +401 −437 lines
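Before the per-file diffs, a small standalone C++ sketch of the operand-numbering convention this commit settles on. This is an illustration, not IREE code: the kMMAOperandLhs/Rhs/Acc names and their use as plain integer indices come from the diffs below, the scaled-MMA indices follow the harmonized lhs, rhs, lhs_scale, rhs_scale, acc order described in the commit message, and the kScaled*Index names are invented purely for this sketch.

#include <cassert>
#include <cstdio>

// Harmonized operand numbering. Plain MMA ops take (lhs, rhs, acc); scaled
// MMA ops take (lhs, rhs, lhs_scale, rhs_scale, acc). The first three
// constants mirror the kMMAOperand* names used in the diffs below; the
// kScaled*Index names are made up here purely for illustration.
constexpr int kMMAOperandLhs = 0;
constexpr int kMMAOperandRhs = 1;
constexpr int kMMAOperandAcc = 2;       // acc of a non-scaled MMA
constexpr int kScaledLhsScaleIndex = 2; // illustrative name
constexpr int kScaledRhsScaleIndex = 3; // illustrative name
constexpr int kScaledAccIndex = 4;      // illustrative name

// The old MMAFragment enum (Lhs == 0, Rhs == 1, Acc == 2) conflated "role"
// with "operand index"; for scaled MMAs the two no longer line up.
const char *describeOperand(bool isScaled, int operandIndex) {
  if (operandIndex == kMMAOperandLhs)
    return "lhs";
  if (operandIndex == kMMAOperandRhs)
    return "rhs";
  if (!isScaled) {
    assert(operandIndex == kMMAOperandAcc && "unexpected operand index");
    return "acc";
  }
  switch (operandIndex) {
  case kScaledLhsScaleIndex:
    return "lhs_scale";
  case kScaledRhsScaleIndex:
    return "rhs_scale";
  case kScaledAccIndex:
    return "acc";
  default:
    assert(false && "unexpected operand index");
    return "";
  }
}

int main() {
  // For scaled MMAs, the accumulator is operand 4; MMAFragment::Acc (== 2)
  // could never express that, and index 2 now means lhs_scale.
  std::printf("scaled operand 2 = %s\n", describeOperand(true, 2));  // lhs_scale
  std::printf("scaled operand 4 = %s\n", describeOperand(true, 4));  // acc
  std::printf("plain  operand 2 = %s\n", describeOperand(false, 2)); // acc
  return 0;
}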

compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMUkernelBitcodeSupport.cpp

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ getCInnermostStaticCrossIntrinsicDim(IREE::Codegen::InnerTiledOp op) {
   }
   auto mma = cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
   IREE::Codegen::TileSwizzle accSwizzle =
-      getSwizzle(mma, IREE::GPU::MMAFragment::Acc);
+      getSwizzle(mma, IREE::GPU::kMMAOperandAcc);
   SmallVector<IREE::Codegen::TileSwizzle::Dim> swizzleDims;
   for (IREE::Codegen::TileSwizzle::ExpandShapeDimVectorType group :
        accSwizzle.expandShape) {

compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp

Lines changed: 3 additions & 5 deletions
@@ -320,26 +320,24 @@ MlirAttribute ireeGPULoweringConfigAttrGetMmaKind(MlirAttribute attr) {
 }
 
 ireeGPUMMASingleSubgroupLayout
-ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t fragment) {
+ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t operandIndex) {
   assert((ireeAttributeIsAGPUMMAIntrinsicAttr(attr) ||
           ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(attr)) &&
          "Expected MMA or VirtualMMA Intrinsic");
 
   mlir::Attribute baseAttr = unwrap(attr);
   mlir::iree_compiler::IREE::GPU::MMASingleSubgroupLayout layout;
-  mlir::iree_compiler::IREE::GPU::MMAFragment frag =
-      static_cast<mlir::iree_compiler::IREE::GPU::MMAFragment>(fragment);
 
   if (auto intrinsicAttr =
           llvm::dyn_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(
               baseAttr)) {
     layout = mlir::iree_compiler::IREE::GPU::getSingleSubgroupLayout(
-        intrinsicAttr.getValue(), frag);
+        intrinsicAttr.getValue(), operandIndex);
   } else if (auto virtualIntrinsicAttr = llvm::dyn_cast<
                  mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr>(
                  baseAttr)) {
     layout = mlir::iree_compiler::IREE::GPU::getSingleSubgroupLayout(
-        virtualIntrinsicAttr.getValue(), frag);
+        virtualIntrinsicAttr.getValue(), operandIndex);
   } else {
     assert(false &&
            "Unreachable: attribute must be MMA or VirtualMMA intrinsic");

compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp

Lines changed: 3 additions & 3 deletions
@@ -59,17 +59,17 @@ static LogicalResult isIntrinsicLayoutCompatible(
   auto [lhsK, rhsK] = opInfo.getOperandKIndex();
   auto [accM, accN] = opInfo.getResultMNIndex();
   if (failed(isSubgroupLayoutCompatible(
-          getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Lhs),
+          getSingleSubgroupLayout(intrinsic, IREE::GPU::kMMAOperandLhs),
           lhsLayout, lhsM, lhsK))) {
     return failure();
   }
   if (failed(isSubgroupLayoutCompatible(
-          getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Rhs),
+          getSingleSubgroupLayout(intrinsic, IREE::GPU::kMMAOperandRhs),
           rhsLayout, rhsK, rhsN))) {
     return failure();
   }
   if (failed(isSubgroupLayoutCompatible(
-          getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Acc),
+          getSingleSubgroupLayout(intrinsic, IREE::GPU::kMMAOperandAcc),
           accLayout, accM, accN))) {
     return failure();
   }

compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp

Lines changed: 6 additions & 6 deletions
@@ -797,24 +797,24 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
     const GPUIntrinsicType &intrinsicA = intrinsics[qkIndex];
     const GPUIntrinsicType &intrinsicB = intrinsics[pvIndex];
     if (!matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
-                                             IREE::GPU::MMAFragment::Acc),
+                                             IREE::GPU::kMMAOperandAcc),
                      getSingleSubgroupLayout(intrinsicB.mmaKind,
-                                             IREE::GPU::MMAFragment::Acc))) {
+                                             IREE::GPU::kMMAOperandAcc))) {
       continue;
     }
 
     // Check if we can reuse the output of intrinsicA for lhs/rhs of
     // intrinsicB.
     bool canReuseAOutForBLhs =
         matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
-                                            IREE::GPU::MMAFragment::Acc),
+                                            IREE::GPU::kMMAOperandAcc),
                     getSingleSubgroupLayout(intrinsicB.mmaKind,
-                                            IREE::GPU::MMAFragment::Lhs));
+                                            IREE::GPU::kMMAOperandLhs));
     bool canReuseAOutForBRhs =
         matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
-                                            IREE::GPU::MMAFragment::Acc),
+                                            IREE::GPU::kMMAOperandAcc),
                     getSingleSubgroupLayout(intrinsicB.mmaKind,
-                                            IREE::GPU::MMAFragment::Rhs));
+                                            IREE::GPU::kMMAOperandRhs));
     intrinsicPairs.push_back(
         {intrinsicA, intrinsicB, canReuseAOutForBLhs || canReuseAOutForBRhs});
   }

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_instrinsics.mlir

Lines changed: 9 additions & 9 deletions
@@ -90,13 +90,13 @@ module {
 // CHECK: iree_codegen.inner_tiled
 // CHECK-SAME: indexing_maps =
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1)>
 // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>}>
-// CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 0, 1>, array<i64: 2, 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
-// CHECK-SAME: : tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU>, tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU> into tensor<?x?x16x16xf32>
+// CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 2, 0, 1>, array<i64: 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
+// CHECK-SAME: : tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU>, tensor<?x?x16x4xf8E8M0FNU> into tensor<?x?x16x16xf32>
 
 // -----
 
@@ -129,13 +129,13 @@ module {
 // CHECK: iree_codegen.inner_tiled
 // CHECK-SAME: indexing_maps =
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1)>
 // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>}>
-// CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 0, 1>, array<i64: 2, 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
-// CHECK-SAME: : tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU>, tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU> into tensor<?x?x16x16xf32>
+// CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 2, 0, 1>, array<i64: 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
+// CHECK-SAME: : tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU>, tensor<?x?x16x4xf8E8M0FNU> into tensor<?x?x16x16xf32>
 
 // -----
 
@@ -168,13 +168,13 @@ module {
 // CHECK: iree_codegen.inner_tiled
 // CHECK-SAME: indexing_maps =
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1)>
 // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E8M0FNU, rhs_elem_type = f8E8M0FNU, acc_elem_type = f32>}>
-// CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 0, 1>, array<i64: 2, 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
-// CHECK-SAME: : tensor<?x?x?x32x2x32xf8E8M0FNU>, tensor<?x?x32x2xf8E8M0FNU>, tensor<?x?x?x32x2x32xf8E8M0FNU>, tensor<?x?x32x2xf8E8M0FNU> into tensor<?x?x32x32xf32>
+// CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 2, 0, 1>, array<i64: 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
+// CHECK-SAME: : tensor<?x?x?x32x2x32xf8E8M0FNU>, tensor<?x?x?x32x2x32xf8E8M0FNU>, tensor<?x?x32x2xf8E8M0FNU>, tensor<?x?x32x2xf8E8M0FNU> into tensor<?x?x32x32xf32>
 
 // -----

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp

Lines changed: 23 additions & 52 deletions
@@ -101,39 +101,13 @@ static void interleave(TileSwizzle &swizzle, size_t srcIdx, int expandedIdx) {
 template <typename MMAIntrinsicTy>
 static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
                                        unsigned operandIdx) {
-  IREE::GPU::MMASingleSubgroupLayout layout;
+  IREE::GPU::MMASingleSubgroupLayout layout =
+      IREE::GPU::getSingleSubgroupLayout(intrinsic, operandIdx);
   const bool isScaled =
       std::is_same<MMAIntrinsicTy, IREE::GPU::ScaledMMAIntrinsic>::value;
-  const unsigned lhsIdx = 0;
-  const unsigned rhsIdx = 1;
-  const unsigned lhsScalesIdx = 2;
-  const unsigned rhsScalesIdx = 3;
-  const bool isLHSorRHS = operandIdx == lhsIdx || operandIdx == rhsIdx;
-  if (isScaled) {
-    // The operand mapping for `getSingleSubgroupLayout` follows a different
-    // operand order than is used for TileSwizzle, so we need to remap the
-    // operandIdx to get the right layout. The layouts for TileSwizzle vs.
-    // `getSingleSubgroupLayout` are shown below:
-    //            | TileSwizzle | getSingleSubgroupLayout
-    // LHS        |      0      |      0
-    // RHS        |      1      |      2
-    // LHS Scales |      2      |      1
-    // RHS Scales |      3      |      3
-    // ACC        |      4      |      4
-    // TODO(Max191): Decide on a consistent operand order for both.
-    int64_t layoutOperandIdx = operandIdx;
-    if (operandIdx == rhsIdx) {
-      layoutOperandIdx = 2;
-    } else if (operandIdx == lhsScalesIdx) {
-      layoutOperandIdx = 1;
-    }
-    layout = IREE::GPU::getSingleSubgroupLayout(
-        static_cast<ScaledMMAIntrinsic>(intrinsic), layoutOperandIdx);
-  } else {
-    layout = IREE::GPU::getSingleSubgroupLayout(
-        static_cast<MMAIntrinsic>(intrinsic),
-        static_cast<IREE::GPU::MMAFragment>(operandIdx));
-  }
+  const bool isLhs = isIntrinsicLhs<MMAIntrinsicTy>(operandIdx);
+  const bool isRhs = isIntrinsicRhs<MMAIntrinsicTy>(operandIdx);
+  const bool isRhsScale = isIntrinsicRhsScale<MMAIntrinsicTy>(operandIdx);
 
   // MMASingleSubgroupLayout has non-transposed RHS and RHS scales, but
   // TileSwizzle has transposed RHS and RHS scales, so reorder the `layout`
@@ -143,7 +117,7 @@ static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
     // rotate right by 1 element to swap [K, Kb] and N.
     std::rotate(v.begin(), v.end() - 1, v.end());
   };
-  if (operandIdx == rhsIdx || (isScaled && operandIdx == rhsScalesIdx)) {
+  if (isRhs || isRhsScale) {
     swapRHSKAndN(layout.outer);
     swapRHSKAndN(layout.thread);
     swapRHSKAndN(layout.tstrides);
@@ -155,7 +129,7 @@ static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
   // All other operands (and LHS/RHS for non-scaled matmuls) have 2 source
   // dimensions. These correspond to the arrays in `layout` all having a
   // matching size. Let's just guard that assumption with one assert here.
-  const unsigned numSrcDims = isScaled && isLHSorRHS ? 3 : 2;
+  const unsigned numSrcDims = isScaled && (isLhs || isRhs) ? 3 : 2;
   assert(layout.thread.size() == numSrcDims &&
          "expected layout rank to match the number of source dims");
   swizzle.expandShape.resize(numSrcDims);
@@ -233,16 +207,14 @@ static size_t getInnermostNonInternalDimIdx(
 template <typename MMAAttrTy>
 static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
   TileSwizzle swizzle = getIntrinsicSwizzle(mma.getIntrinsic(), operandIdx);
-  const bool isScaled =
-      std::is_same<MMAAttrTy, IREE::GPU::DataTiledScaledMMAAttr>::value;
-  const unsigned lhsIdx = 0;
-  const unsigned rhsIdx = 1;
-  const unsigned lhsScalesIdx = 2;
-  const unsigned rhsScalesIdx = 3;
-  const unsigned accIdx = isScaled ? 4 : 2;
-  const bool isRhsScales = isScaled && operandIdx == rhsScalesIdx;
-  const bool isLhsScales = isScaled && operandIdx == lhsScalesIdx;
-  if (operandIdx == lhsIdx || isLhsScales) {
+  using MMAIntrinsicTy = decltype(mma.getIntrinsic());
+  const bool isScaled = std::is_same<MMAIntrinsicTy, ScaledMMAIntrinsic>::value;
+  const bool isLhs = isIntrinsicLhs<MMAIntrinsicTy>(operandIdx);
+  const bool isRhs = isIntrinsicRhs<MMAIntrinsicTy>(operandIdx);
+  const bool isAcc = isIntrinsicAcc<MMAIntrinsicTy>(operandIdx);
+  const bool isLhsScale = isIntrinsicLhsScale<MMAIntrinsicTy>(operandIdx);
+  const bool isRhsScale = isIntrinsicRhsScale<MMAIntrinsicTy>(operandIdx);
+  if (isLhs || isLhsScale) {
     // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
     // Unroll on K with interleaving, then on M.
     if (mma.getIntrinsicsK() > 1) {
@@ -253,10 +225,10 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
       // the unrolled scales with each vector load, so we need to interleave at
       // the very last dimension for the scales. For the LHS, we load in blocks,
       // so we don't need to interleave.
-      if (isLhsScales) {
+      if (isLhsScale) {
         interleavingIdx = swizzle.expandShape[1].size() - 1;
       }
-      if (!isScaled || isLhsScales) {
+      if (!isScaled || isLhsScale) {
         interleave(swizzle, 1, interleavingIdx);
       }
     }
@@ -272,7 +244,7 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
                      mma.getSubgroupsM() * mma.getSubgroupsN());
       expand(swizzle, 0, dim);
     }
-  } else if (operandIdx == rhsIdx || isRhsScales) {
+  } else if (isRhs || isRhsScale) {
     // B-matrix (RHS). Since the pack ops already took care of transposing B,
     // source dimensions are N (index 0) and K (index 1).
     // Unroll on K with interleaving, then on N.
@@ -282,10 +254,10 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
           getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
       // Like with the LHS above, we want to interleave such that we load all
       // the unrolled scales with each vector load.
-      if (isRhsScales) {
+      if (isRhsScale) {
         interleavingIdx = swizzle.expandShape[1].size() - 1;
       }
-      if (!isScaled || isRhsScales) {
+      if (!isScaled || isRhsScale) {
         interleave(swizzle, 1, interleavingIdx);
       }
     }
@@ -295,7 +267,7 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
     if (mma.getSubgroupsN() > 1) {
       expand(swizzle, 0, {Kind::CrossThread, mma.getSubgroupsN()});
     }
-  } else if (operandIdx == accIdx) {
+  } else if (isAcc) {
     // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
     // 1). Unroll on N, then on M.
     if (mma.getIntrinsicsN() > 1) {
@@ -319,9 +291,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledScaledMMAAttr scaledMma,
   return getSwizzleImpl(scaledMma, operandIdx);
 }
 
-TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
-                       IREE::GPU::MMAFragment fragment) {
-  return getSwizzleImpl(mma, static_cast<unsigned>(fragment));
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, int operandIndex) {
+  return getSwizzleImpl(mma, operandIndex);
 }
 
 /// Remove the expanded dimensions for this index and update the permutation by
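The isIntrinsicLhs / isIntrinsicRhs / isIntrinsicLhsScale / isIntrinsicRhsScale / isIntrinsicAcc helpers referenced above are defined elsewhere in this PR and are not part of this hunk. The following self-contained, hypothetical sketch shows one plausible shape for such predicates under the harmonized order; the stub enum types and the helper bodies are assumptions, not the actual IREE definitions.

#include <type_traits>

// Stand-ins for the real MMAIntrinsic / ScaledMMAIntrinsic enums, used only
// to keep this sketch self-contained.
enum class MMAIntrinsicStub { Dummy };
enum class ScaledMMAIntrinsicStub { Dummy };

// Harmonized operand order: (lhs, rhs, acc) for plain MMA,
// (lhs, rhs, lhs_scale, rhs_scale, acc) for scaled MMA.
template <typename IntrinsicTy>
constexpr bool isScaledIntrinsic() {
  return std::is_same_v<IntrinsicTy, ScaledMMAIntrinsicStub>;
}

template <typename IntrinsicTy>
constexpr bool isIntrinsicLhs(unsigned operandIdx) {
  return operandIdx == 0;
}

template <typename IntrinsicTy>
constexpr bool isIntrinsicRhs(unsigned operandIdx) {
  return operandIdx == 1;
}

template <typename IntrinsicTy>
constexpr bool isIntrinsicLhsScale(unsigned operandIdx) {
  return isScaledIntrinsic<IntrinsicTy>() && operandIdx == 2;
}

template <typename IntrinsicTy>
constexpr bool isIntrinsicRhsScale(unsigned operandIdx) {
  return isScaledIntrinsic<IntrinsicTy>() && operandIdx == 3;
}

template <typename IntrinsicTy>
constexpr bool isIntrinsicAcc(unsigned operandIdx) {
  return operandIdx == (isScaledIntrinsic<IntrinsicTy>() ? 4u : 2u);
}

// Spot checks of the convention.
static_assert(isIntrinsicLhsScale<ScaledMMAIntrinsicStub>(2), "lhs_scale is operand 2");
static_assert(isIntrinsicAcc<ScaledMMAIntrinsicStub>(4), "scaled acc is operand 4");
static_assert(isIntrinsicAcc<MMAIntrinsicStub>(2), "plain acc is operand 2");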

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ SmallVector<int64_t> sliceSwizzledShape(
 /// Returns the swizzle for the full data-tiled-mma tile, including all the
 /// relevant unrolling and expansion factors.
 Codegen::TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
-                                IREE::GPU::MMAFragment fragment);
+                                int operandIndex);
 
 /// Returns the swizzle for the full data-tiled-scaled-mma tile, including all
 /// the relevant unrolling and expansion factors.

0 commit comments
