@@ -81,7 +81,7 @@ getCInnermostStaticCrossIntrinsicDim(IREE::Codegen::InnerTiledOp op) {
}
auto mma = cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
IREE::Codegen::TileSwizzle accSwizzle =
- getSwizzle(mma, IREE::GPU::MMAFragment::Acc);
+ getSwizzle(mma, IREE::GPU::kMMAOperandAcc);
SmallVector<IREE::Codegen::TileSwizzle::Dim> swizzleDims;
for (IREE::Codegen::TileSwizzle::ExpandShapeDimVectorType group :
accSwizzle.expandShape) {
@@ -320,26 +320,24 @@ MlirAttribute ireeGPULoweringConfigAttrGetMmaKind(MlirAttribute attr) {
}

ireeGPUMMASingleSubgroupLayout
- ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t fragment) {
+ ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t operandIndex) {
assert((ireeAttributeIsAGPUMMAIntrinsicAttr(attr) ||
ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(attr)) &&
"Expected MMA or VirtualMMA Intrinsic");

mlir::Attribute baseAttr = unwrap(attr);
mlir::iree_compiler::IREE::GPU::MMASingleSubgroupLayout layout;
- mlir::iree_compiler::IREE::GPU::MMAFragment frag =
- static_cast<mlir::iree_compiler::IREE::GPU::MMAFragment>(fragment);

if (auto intrinsicAttr =
llvm::dyn_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr>(
baseAttr)) {
layout = mlir::iree_compiler::IREE::GPU::getSingleSubgroupLayout(
- intrinsicAttr.getValue(), frag);
+ intrinsicAttr.getValue(), operandIndex);
} else if (auto virtualIntrinsicAttr = llvm::dyn_cast<
mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr>(
baseAttr)) {
layout = mlir::iree_compiler::IREE::GPU::getSingleSubgroupLayout(
- virtualIntrinsicAttr.getValue(), frag);
+ virtualIntrinsicAttr.getValue(), operandIndex);
} else {
assert(false &&
"Unreachable: attribute must be MMA or VirtualMMA intrinsic");
@@ -59,17 +59,17 @@ static LogicalResult isIntrinsicLayoutCompatible(
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
auto [accM, accN] = opInfo.getResultMNIndex();
if (failed(isSubgroupLayoutCompatible(
- getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Lhs),
+ getSingleSubgroupLayout(intrinsic, IREE::GPU::kMMAOperandLhs),
lhsLayout, lhsM, lhsK))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(
- getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Rhs),
+ getSingleSubgroupLayout(intrinsic, IREE::GPU::kMMAOperandRhs),
rhsLayout, rhsK, rhsN))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(
- getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Acc),
+ getSingleSubgroupLayout(intrinsic, IREE::GPU::kMMAOperandAcc),
accLayout, accM, accN))) {
return failure();
}
@@ -797,24 +797,24 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
const GPUIntrinsicType &intrinsicA = intrinsics[qkIndex];
const GPUIntrinsicType &intrinsicB = intrinsics[pvIndex];
if (!matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
- IREE::GPU::MMAFragment::Acc),
+ IREE::GPU::kMMAOperandAcc),
getSingleSubgroupLayout(intrinsicB.mmaKind,
- IREE::GPU::MMAFragment::Acc))) {
+ IREE::GPU::kMMAOperandAcc))) {
continue;
}

// Check if we can reuse the output of intrinsicA for lhs/rhs of
// intrinsicB.
bool canReuseAOutForBLhs =
matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
- IREE::GPU::MMAFragment::Acc),
+ IREE::GPU::kMMAOperandAcc),
getSingleSubgroupLayout(intrinsicB.mmaKind,
- IREE::GPU::MMAFragment::Lhs));
+ IREE::GPU::kMMAOperandLhs));
bool canReuseAOutForBRhs =
matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
- IREE::GPU::MMAFragment::Acc),
+ IREE::GPU::kMMAOperandAcc),
getSingleSubgroupLayout(intrinsicB.mmaKind,
- IREE::GPU::MMAFragment::Rhs));
+ IREE::GPU::kMMAOperandRhs));
intrinsicPairs.push_back(
{intrinsicA, intrinsicB, canReuseAOutForBLhs || canReuseAOutForBRhs});
}
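As an aside, a hedged sketch of the property the reuse check above is looking for. This is not IREE's matchLayout implementation; the function name is invented, and it only compares the MMASingleSubgroupLayout fields that appear elsewhere in this diff (outer, thread, tstrides).

// Sketch only: intrinsicA's ACC layout can feed intrinsicB's LHS/RHS without a
// relayout only if the two single-subgroup layouts agree field by field.
static bool subgroupLayoutsAgree(const IREE::GPU::MMASingleSubgroupLayout &a,
                                 const IREE::GPU::MMASingleSubgroupLayout &b) {
  return a.outer == b.outer && a.thread == b.thread && a.tstrides == b.tstrides;
}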
@@ -90,13 +90,13 @@ module {
// CHECK: iree_codegen.inner_tiled
// CHECK-SAME: indexing_maps =
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
- // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+ // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>}>
- // CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 0, 1>, array<i64: 2, 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
- // CHECK-SAME: : tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU>, tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU> into tensor<?x?x16x16xf32>
+ // CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 2, 0, 1>, array<i64: 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
+ // CHECK-SAME: : tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU>, tensor<?x?x16x4xf8E8M0FNU> into tensor<?x?x16x16xf32>

// -----

@@ -129,13 +129,13 @@ module {
// CHECK: iree_codegen.inner_tiled
// CHECK-SAME: indexing_maps =
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
- // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+ // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>}>
- // CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 0, 1>, array<i64: 2, 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
- // CHECK-SAME: : tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU>, tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU> into tensor<?x?x16x16xf32>
+ // CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 2, 0, 1>, array<i64: 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
+ // CHECK-SAME: : tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x?x16x4x32xf4E2M1FN>, tensor<?x?x16x4xf8E8M0FNU>, tensor<?x?x16x4xf8E8M0FNU> into tensor<?x?x16x16xf32>

// -----

@@ -168,13 +168,13 @@ module {
// CHECK: iree_codegen.inner_tiled
// CHECK-SAME: indexing_maps =
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
- // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+ // CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_32x32x64_B32, lhs_elem_type = f8E8M0FNU, rhs_elem_type = f8E8M0FNU, acc_elem_type = f32>}>
- // CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 0, 1>, array<i64: 2, 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
- // CHECK-SAME: : tensor<?x?x?x32x2x32xf8E8M0FNU>, tensor<?x?x32x2xf8E8M0FNU>, tensor<?x?x?x32x2x32xf8E8M0FNU>, tensor<?x?x32x2xf8E8M0FNU> into tensor<?x?x32x32xf32>
+ // CHECK-SAME: permutations = [array<i64: 0, 1, 2>, array<i64: 2, 0, 1>, array<i64: 0, 1>, array<i64: 1, 0>, array<i64: 0, 1>]
+ // CHECK-SAME: : tensor<?x?x?x32x2x32xf8E8M0FNU>, tensor<?x?x?x32x2x32xf8E8M0FNU>, tensor<?x?x32x2xf8E8M0FNU>, tensor<?x?x32x2xf8E8M0FNU> into tensor<?x?x32x32xf32>

// -----

@@ -101,39 +101,13 @@ static void interleave(TileSwizzle &swizzle, size_t srcIdx, int expandedIdx) {
template <typename MMAIntrinsicTy>
static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
unsigned operandIdx) {
- IREE::GPU::MMASingleSubgroupLayout layout;
+ IREE::GPU::MMASingleSubgroupLayout layout =
+ IREE::GPU::getSingleSubgroupLayout(intrinsic, operandIdx);
const bool isScaled =
std::is_same<MMAIntrinsicTy, IREE::GPU::ScaledMMAIntrinsic>::value;
- const unsigned lhsIdx = 0;
- const unsigned rhsIdx = 1;
- const unsigned lhsScalesIdx = 2;
- const unsigned rhsScalesIdx = 3;
- const bool isLHSorRHS = operandIdx == lhsIdx || operandIdx == rhsIdx;
- if (isScaled) {
- // The operand mapping for `getSingleSubgroupLayout` follows a different
- // operand order than is used for TileSwizzle, so we need to remap the
- // operandIdx to get the right layout. The layouts for TileSwizzle vs.
- // `getSingleSubgroupLayout` are shown below:
- // | TileSwizzle | getSingleSubgroupLayout
- // LHS | 0 | 0
- // RHS | 1 | 2
- // LHS Scales | 2 | 1
- // RHS Scales | 3 | 3
- // ACC | 4 | 4
- // TODO(Max191): Decide on a consistent operand order for both.
- int64_t layoutOperandIdx = operandIdx;
- if (operandIdx == rhsIdx) {
- layoutOperandIdx = 2;
- } else if (operandIdx == lhsScalesIdx) {
- layoutOperandIdx = 1;
- }
- layout = IREE::GPU::getSingleSubgroupLayout(
- static_cast<ScaledMMAIntrinsic>(intrinsic), layoutOperandIdx);
- } else {
- layout = IREE::GPU::getSingleSubgroupLayout(
- static_cast<MMAIntrinsic>(intrinsic),
- static_cast<IREE::GPU::MMAFragment>(operandIdx));
- }
+ const bool isLhs = isIntrinsicLhs<MMAIntrinsicTy>(operandIdx);
+ const bool isRhs = isIntrinsicRhs<MMAIntrinsicTy>(operandIdx);
+ const bool isRhsScale = isIntrinsicRhsScale<MMAIntrinsicTy>(operandIdx);

// MMASingleSubgroupLayout has non-transposed RHS and RHS scales, but
// TileSwizzle has transposed RHS and RHS scales, so reorder the `layout`
@@ -143,7 +117,7 @@ static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
// rotate right by 1 element to swap [K, Kb] and N.
std::rotate(v.begin(), v.end() - 1, v.end());
};
- if (operandIdx == rhsIdx || (isScaled && operandIdx == rhsScalesIdx)) {
+ if (isRhs || isRhsScale) {
swapRHSKAndN(layout.outer);
swapRHSKAndN(layout.thread);
swapRHSKAndN(layout.tstrides);
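To make the rotation above concrete, here is a standalone C++ illustration (not IREE code, and not part of this diff) of what swapRHSKAndN does to a scaled RHS dimension list; the string labels are purely illustrative.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Scaled RHS source dims as described in the comment above: [K, Kb, N].
  std::vector<std::string> dims = {"K", "Kb", "N"};
  // Rotate right by one element: the trailing N moves to the front, past [K, Kb].
  std::rotate(dims.begin(), dims.end() - 1, dims.end());
  for (const std::string &d : dims)
    std::cout << d << ' ';
  std::cout << '\n';  // prints: N K Kb
  return 0;
}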
@@ -155,7 +129,7 @@ static TileSwizzle getIntrinsicSwizzle(MMAIntrinsicTy intrinsic,
// All other operands (and LHS/RHS for non-scaled matmuls) have 2 source
// dimensions. These correspond to the arrays in `layout` all having a
// matching size. Let's just guard that assumption with one assert here.
- const unsigned numSrcDims = isScaled && isLHSorRHS ? 3 : 2;
+ const unsigned numSrcDims = isScaled && (isLhs || isRhs) ? 3 : 2;
assert(layout.thread.size() == numSrcDims &&
"expected layout rank to match the number of source dims");
swizzle.expandShape.resize(numSrcDims);
@@ -233,16 +207,14 @@ static size_t getInnermostNonInternalDimIdx(
template <typename MMAAttrTy>
static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
TileSwizzle swizzle = getIntrinsicSwizzle(mma.getIntrinsic(), operandIdx);
- const bool isScaled =
- std::is_same<MMAAttrTy, IREE::GPU::DataTiledScaledMMAAttr>::value;
- const unsigned lhsIdx = 0;
- const unsigned rhsIdx = 1;
- const unsigned lhsScalesIdx = 2;
- const unsigned rhsScalesIdx = 3;
- const unsigned accIdx = isScaled ? 4 : 2;
- const bool isRhsScales = isScaled && operandIdx == rhsScalesIdx;
- const bool isLhsScales = isScaled && operandIdx == lhsScalesIdx;
- if (operandIdx == lhsIdx || isLhsScales) {
+ using MMAIntrinsicTy = decltype(mma.getIntrinsic());
+ const bool isScaled = std::is_same<MMAIntrinsicTy, ScaledMMAIntrinsic>::value;
+ const bool isLhs = isIntrinsicLhs<MMAIntrinsicTy>(operandIdx);
+ const bool isRhs = isIntrinsicRhs<MMAIntrinsicTy>(operandIdx);
+ const bool isAcc = isIntrinsicAcc<MMAIntrinsicTy>(operandIdx);
+ const bool isLhsScale = isIntrinsicLhsScale<MMAIntrinsicTy>(operandIdx);
+ const bool isRhsScale = isIntrinsicRhsScale<MMAIntrinsicTy>(operandIdx);
Review comment (Contributor), on lines +210 to +216: It looks much better than magic numbers to me.
+ if (isLhs || isLhsScale) {
// A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
// Unroll on K with interleaving, then on M.
if (mma.getIntrinsicsK() > 1) {
@@ -253,10 +225,10 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
// the unrolled scales with each vector load, so we need to interleave at
// the very last dimension for the scales. For the LHS, we load in blocks,
// so we don't need to interleave.
- if (isLhsScales) {
+ if (isLhsScale) {
interleavingIdx = swizzle.expandShape[1].size() - 1;
}
- if (!isScaled || isLhsScales) {
+ if (!isScaled || isLhsScale) {
interleave(swizzle, 1, interleavingIdx);
}
}
@@ -272,7 +244,7 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
mma.getSubgroupsM() * mma.getSubgroupsN());
expand(swizzle, 0, dim);
}
- } else if (operandIdx == rhsIdx || isRhsScales) {
+ } else if (isRhs || isRhsScale) {
// B-matrix (RHS). Since the pack ops already took care of transposing B,
// source dimensions are N (index 0) and K (index 1).
// Unroll on K with interleaving, then on N.
@@ -282,10 +254,10 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
// Like with the LHS above, we want to interleave such that we load all
// the unrolled scales with each vector load.
- if (isRhsScales) {
+ if (isRhsScale) {
interleavingIdx = swizzle.expandShape[1].size() - 1;
}
- if (!isScaled || isRhsScales) {
+ if (!isScaled || isRhsScale) {
interleave(swizzle, 1, interleavingIdx);
}
}
@@ -295,7 +267,7 @@ static TileSwizzle getSwizzleImpl(MMAAttrTy mma, unsigned operandIdx) {
if (mma.getSubgroupsN() > 1) {
expand(swizzle, 0, {Kind::CrossThread, mma.getSubgroupsN()});
}
- } else if (operandIdx == accIdx) {
+ } else if (isAcc) {
// C-matrix (accumulator). Source dimensions are M (index 0) and N (index
// 1). Unroll on N, then on M.
if (mma.getIntrinsicsN() > 1) {
Expand All @@ -319,9 +291,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledScaledMMAAttr scaledMma,
return getSwizzleImpl(scaledMma, operandIdx);
}

- TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
- IREE::GPU::MMAFragment fragment) {
- return getSwizzleImpl(mma, static_cast<unsigned>(fragment));
+ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, int operandIndex) {
+ return getSwizzleImpl(mma, operandIndex);
}

/// Remove the expanded dimensions for this index and update the permutation by
@@ -22,7 +22,7 @@ SmallVector<int64_t> sliceSwizzledShape(
/// Returns the swizzle for the full data-tiled-mma tile, including all the
/// relevant unrolling and expansion factors.
Codegen::TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
- IREE::GPU::MMAFragment fragment);
+ int operandIndex);

/// Returns the swizzle for the full data-tiled-scaled-mma tile, including all
/// the relevant unrolling and expansion factors.
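Tying the header change back to the call site in the first hunk: one benefit of taking a plain operand index instead of an MMAFragment is that callers can simply loop over operands. A hypothetical loop, not code from this PR, assuming the usual LLVM/IREE using-directives are in scope and that indices 0/1/2 correspond to LHS/RHS/ACC for a non-scaled data-tiled MMA, as in the pre-existing ordering:

// `mma` is an IREE::GPU::DataTiledMMAAttr; collect the swizzles for all three
// operands of a non-scaled data-tiled MMA in one pass.
SmallVector<IREE::Codegen::TileSwizzle> swizzles;
for (int operandIndex = 0; operandIndex < 3; ++operandIndex)
  swizzles.push_back(getSwizzle(mma, operandIndex));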