[AMD] Implement RepOrder for AMD MMA layouts and change kMajor notation to kMinor #5126

Draft · wants to merge 3 commits into base: main
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -134,11 +134,11 @@ unsigned getNumCTAs(Attribute layout);
 // len(shape) == rank.
 SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

-// Return the order that represents that the dot operand is in kMajor
+// Return the order that represents that the dot operand is in kMinor
 // (contiguous in the inner dimension) or it's contiguous on the outer
 // dimension.
 SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
-                                            bool kMajor);
+                                            bool kMinor);

 bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -785,6 +785,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
"getSizePerThreadForOperand",
(ins "int":$opIdx,
"int":$kWidth)>,

InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
"SmallVector<unsigned>",
"getRepOrderForOperand",
(ins "int":$opIdx)>,
];
}

33 changes: 24 additions & 9 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -249,15 +249,15 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
 }

 SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
-                                            bool kMajor) {
-  // kMajor: if true, the matrix is fastest-running on k,
+                                            bool kMinor) {
+  // kMinor: if true, the matrix is fastest-running on k,
   //         otherwise it is on m (resp. n)
   // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
   // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
   // batch (if rank == 3) is always the slowest running dimension
   assert(rank == 2 || rank == 3);
   assert(opIdx == 0 || opIdx == 1);
-  auto rowMajor = bool(opIdx) != kMajor;
+  auto rowMajor = bool(opIdx) != kMinor;
   return getMatrixOrder(rank, rowMajor);
 }
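Note on the rename's semantics: `kMinor == true` means the operand is contiguous along k, and the flag combines with `opIdx` as an XOR, since A is `[m, k]` (k-fastest is row-major) while B is `[k, n]` (k-fastest is column-major). Below is a minimal standalone sketch of this truth table, assuming `getMatrixOrder` returns `{rank-1, ..., 0}` for row-major and swaps the first two entries otherwise; the names here are illustrative stand-ins, not Triton's API:

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Illustrative mirror of getMatrixOrder: dimensions listed fastest-first.
std::vector<unsigned> matrixOrder(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0); // {rank-1, ..., 1, 0}
  if (!rowMajor)
    std::swap(order[0], order[1]);
  return order;
}

// Illustrative mirror of getOrderForDotOperand after the rename.
std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank,
                                         bool kMinor) {
  bool rowMajor = bool(opIdx) != kMinor;
  return matrixOrder(rank, rowMajor);
}

int main() {
  using V = std::vector<unsigned>;
  assert(orderForDotOperand(0, 2, true) == (V{1, 0}));  // A [m, k]: k fastest
  assert(orderForDotOperand(1, 2, true) == (V{0, 1}));  // B [k, n]: k fastest
  assert(orderForDotOperand(0, 2, false) == (V{0, 1})); // A [m, k]: m fastest
  assert(orderForDotOperand(1, 2, false) == (V{1, 0})); // B [k, n]: n fastest
}
```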

@@ -290,7 +290,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = dotLayout.getWarpsPerCTA().size();
-    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
+    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMinor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
@@ -1034,7 +1034,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
-                               /*kMajor*/ true);
+                               /*kMinor*/ true);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {
@@ -1652,7 +1652,14 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
 }

 SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
-  llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder");
+  auto rank = getWarpsPerCTA().size();
+  return getMatrixOrder(rank, /*rowMajor*/ true);
+}
+
+SmallVector<unsigned>
+AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true);
 }

 SmallVector<int64_t>
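Note: with these definitions the rep order of the MFMA layout itself (and of WMMA below) is plain row-major over the rep grid, while the operand rep orders keep k fastest and batch slowest. A hand-derived, compilable sanity sketch for the batched rank-3 case (expected values worked out from the code above, not calls into the real attributes):

```cpp
#include <cassert>
#include <vector>

int main() {
  using V = std::vector<unsigned>;
  // getRepOrder(): row-major over [batch, m, n] -> {2, 1, 0}.
  V repOrder = {2, 1, 0};
  // getRepOrderForOperand(0): A is [batch, m, k], so k (dim 2) is fastest.
  V repOrderA = {2, 1, 0};
  // getRepOrderForOperand(1): B is [batch, k, n], so k (dim 1) is fastest.
  V repOrderB = {1, 2, 0};
  assert(repOrderA.front() == 2 && repOrderB.front() == 1); // k is minor
  assert(repOrder.back() == 0 && repOrderA.back() == 0 &&
         repOrderB.back() == 0); // batch is always slowest
}
```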
@@ -1739,8 +1746,16 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
   return shapePerCTATile;
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
-  llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder");
+  auto rank = getWarpsPerCTA().size();
+  return getMatrixOrder(rank, /*rowMajor*/ true);
+}
+
+SmallVector<unsigned>
+AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true);
+}

 SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1948,7 +1963,7 @@ NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
-  return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
+  return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true);
 }

 SmallVector<int64_t>
@@ -2028,7 +2043,7 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
 // DotOperand Encoding
 //===----------------------------------------------------------------------===//
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
-  if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+  if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
   }
   llvm::report_fatal_error(
18 changes: 9 additions & 9 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -875,16 +875,16 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,

   MLIRContext *ctx = mma.getContext();

-  // The A and B operands are tiled in a kMajor fashion
-  auto kMajorOrder = dot.getRepOrder();
-  assert(kMajorOrder ==
-         getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));
+  // The A and B operands are tiled in a kMinor fashion
+  auto kMinorOrder = dot.getRepOrder();
+  assert(kMinorOrder ==
+         getOrderForDotOperand(dot.getOpIdx(), rank, /*kMinor=*/true));

-  auto kMajorDims =
-      permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder);
+  auto kMinorDims =
+      permuteDimNames(standardOutDimNames(ctx, rank), kMinorOrder);
   // This agrees with the order of the elements, which means that we can share
   // the code below for both A and B without having to perform any swaps
-  assert(getOrder(dot) == kMajorOrder);
+  assert(getOrder(dot) == kMinorOrder);

   std::vector<std::vector<int32_t>> registers;
   std::vector<std::vector<int32_t>> lanes;
@@ -911,7 +911,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
     registers.push_back({i, 0});

   LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
-                         ArrayRef(kMajorDims).take_front(2));
+                         ArrayRef(kMinorDims).take_front(2));

   // Let warpsPerCTAMma = {2, 2}, then
   // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
@@ -952,7 +952,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
     }
   }

-  ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);
+  ctaLayout *= LinearLayout({{S("warp"), warps}}, kMinorDims);

   return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }
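Note: this is the payoff of the kMinor ordering. After permuting the standard dim names by `kMinorOrder`, the k dimension's name comes first for both operands, so the register/lane bases above can be written once. A self-contained illustration (`permute` is a hypothetical stand-in for `permuteDimNames`, and the `"dim0"`/`"dim1"` strings stand in for `standardOutDimNames` output):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for permuteDimNames: reorder names fastest-first.
std::vector<std::string> permute(const std::vector<std::string> &names,
                                 const std::vector<unsigned> &order) {
  std::vector<std::string> out;
  for (unsigned d : order)
    out.push_back(names[d]);
  return out;
}

int main() {
  const std::vector<std::string> dims = {"dim0", "dim1"};
  // A ([m, k], opIdx = 0): kMinorOrder = {1, 0} -> k ("dim1") leads.
  assert(permute(dims, {1, 0}) == (std::vector<std::string>{"dim1", "dim0"}));
  // B ([k, n], opIdx = 1): kMinorOrder = {0, 1} -> k ("dim0") leads.
  assert(permute(dims, {0, 1}) == (std::vector<std::string>{"dim0", "dim1"}));
}
```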
@@ -76,7 +76,7 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
   return base;
 }

-bool isKMajor(llvm::ArrayRef<unsigned> order, int opIdx) {
+bool isKMinor(llvm::ArrayRef<unsigned> order, int opIdx) {
   auto rank = order.size();
   int kdim = opIdx == 0 ? rank - 1 : rank - 2;
   return order[0] == kdim;
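Note: the predicate is unchanged; the new name simply matches what it computes. `order` lists dimensions fastest-first, and k is the last dim for A (`opIdx == 0`) or the second-to-last for B (`opIdx == 1`), so `order[0] == kdim` asks whether k is the minor (contiguous) dimension. A standalone sketch covering the four rank-2 cases (illustrative, not the Triton helper):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Standalone mirror of the renamed predicate for rank-2 operands.
bool isKMinorSketch(const std::vector<unsigned> &order, int opIdx) {
  std::size_t rank = order.size();
  std::size_t kdim = opIdx == 0 ? rank - 1 : rank - 2;
  return order[0] == kdim;
}

int main() {
  assert(isKMinorSketch({1, 0}, 0));  // A [m, k], k contiguous
  assert(!isKMinorSketch({0, 1}, 0)); // A [m, k], m contiguous
  assert(isKMinorSketch({0, 1}, 1));  // B [k, n], k contiguous
  assert(!isKMinorSketch({1, 0}, 1)); // B [k, n], n contiguous
}
```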
@@ -106,9 +106,9 @@ bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout,
   const auto swizzleSlowDimSize =
       sharedLayout.getMaxPhase() * sharedLayout.getPerPhase();
   const auto swizzlePatternSizeK =
-      isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
+      isKMinor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
   const auto swizzlePatternSizeNonK =
-      !isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
+      !isKMinor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;

   const auto blockSizeK = mfmaInstrK * reps[reps.size() - 1];
   const auto blockSizeNonK = mfmaInstrNonK * warpsPerBlockNonK;
@@ -37,7 +37,7 @@ Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
 Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
                      const SharedMemoryObject &smemObj);

-bool isKMajor(llvm::ArrayRef<unsigned> order, int opIdx);
+bool isKMinor(llvm::ArrayRef<unsigned> order, int opIdx);

 using computeTensorElemMappingInBlockT =
     std::function<llvm::SmallVector<llvm::SmallVector<Value>>(
@@ -279,7 +279,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   SmallVector<Value> offsets;
   Value smemBase;
   bool isFastPath =
-      !AMD::isKMajor(order, opIdx) && !hasSwizzleEnabled(sharedLayout);
+      !AMD::isKMinor(order, opIdx) && !hasSwizzleEnabled(sharedLayout);
   if (isFastPath) {
     // fast path handles tensors that are not k-major and have swizzling
     // disabled, in which case offsets computation can be simplified
@@ -417,13 +417,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
   // getValuesFromDotOperandLayoutStruct as both a and b are K-major
   assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(),
                                                        aShapePerCTA.size(),
-                                                       /*kMajor=*/true));
+                                                       /*kMinor=*/true));
   auto ha = getValuesFromDotOperandLayoutStruct(
       typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy);

   assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(),
                                                        bShapePerCTA.size(),
-                                                       /*kMajor=*/true));
+                                                       /*kMinor=*/true));
   auto hb = getValuesFromDotOperandLayoutStruct(
       typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy);