diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index a9b49448c1d0..d88950e1a011 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -134,11 +134,11 @@ unsigned getNumCTAs(Attribute layout); // len(shape) == rank. SmallVector getMatrixOrder(unsigned rank, bool rowMajor); -// Return the order that represents that the dot operand is in kMajor +// Return the order that represents that the dot operand is in kMinor // (contiguous in the inner dimension) or it's contiguous on the outer // dimension. SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, - bool kMajor); + bool kMinor); bool isExpensiveCat(CatOp cat, Attribute targetEncoding); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index c3bf66def421..f2248f1e8cd7 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -785,6 +785,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { "getSizePerThreadForOperand", (ins "int":$opIdx, "int":$kWidth)>, + + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrderForOperand", + (ins "int":$opIdx)>, ]; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a49b793e3044..16d717784234 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -249,15 +249,15 @@ SmallVector getMatrixOrder(unsigned rank, bool rowMajor) { } SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, - bool kMajor) { - // kMajor: if true, the matrix is fastest-running on k, + bool kMinor) { + // kMinor: if true, the matrix is fastest-running on k, // otherwise it is on m (resp. n) // opIdx=0: [batch, m, k] if rank == 3 else [m, k] // opIdx=1: [batch, k, n] if rank == 3 else [k, n] // batch (if rank == 3) is always the slowest running dimension assert(rank == 2 || rank == 3); assert(opIdx == 0 || opIdx == 1); - auto rowMajor = bool(opIdx) != kMajor; + auto rowMajor = bool(opIdx) != kMinor; return getMatrixOrder(rank, rowMajor); } @@ -290,7 +290,7 @@ SmallVector getOrder(Attribute layout) { } if (auto dotLayout = dyn_cast(layout)) { auto rank = dotLayout.getWarpsPerCTA().size(); - return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true); + return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMinor*/ true); } if (auto sliceLayout = dyn_cast(layout)) { SmallVector parentOrder = getOrder(sliceLayout.getParent()); @@ -1034,7 +1034,7 @@ SmallVector DotOperandEncodingAttr::getWarpOrder() const { } SmallVector DotOperandEncodingAttr::getThreadOrder() const { return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), - /*kMajor*/ true); + /*kMinor*/ true); } SmallVector DotOperandEncodingAttr::getShapePerCTATile( ArrayRef tensorShape) const { @@ -1652,7 +1652,14 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { } SmallVector AMDMfmaEncodingAttr::getRepOrder() const { - llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder"); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true); } SmallVector @@ -1739,8 +1746,16 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { return shapePerCTATile; } SmallVector AMDWmmaEncodingAttr::getRepOrder() const { - llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder"); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +SmallVector +AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true); } + SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1948,7 +1963,7 @@ NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { SmallVector NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { auto rank = getWarpsPerCTA().size(); - return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); + return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true); } SmallVector @@ -2028,7 +2043,7 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { // DotOperand Encoding //===----------------------------------------------------------------------===// SmallVector DotOperandEncodingAttr::getRepOrder() const { - if (auto mma = mlir::dyn_cast(getParent())) { + if (auto mma = mlir::dyn_cast(getParent())) { return mma.getRepOrderForOperand(getOpIdx()); } llvm::report_fatal_error( diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 43c87af487a1..bc6541f16b0e 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -875,16 +875,16 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, MLIRContext *ctx = mma.getContext(); - // The A and B operands are tiled in a kMajor fashion - auto kMajorOrder = dot.getRepOrder(); - assert(kMajorOrder == - getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true)); + // The A and B operands are tiled in a kMinor fashion + auto kMinorOrder = dot.getRepOrder(); + assert(kMinorOrder == + getOrderForDotOperand(dot.getOpIdx(), rank, /*kMinor=*/true)); - auto kMajorDims = - permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder); + auto kMinorDims = + permuteDimNames(standardOutDimNames(ctx, rank), kMinorOrder); // This agrees with the order of the elements, which means that we can share // the code below for both A and B without having to perform any swaps - assert(getOrder(dot) == kMajorOrder); + assert(getOrder(dot) == kMinorOrder); std::vector> registers; std::vector> lanes; @@ -911,7 +911,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, registers.push_back({i, 0}); LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}}, - ArrayRef(kMajorDims).take_front(2)); + ArrayRef(kMinorDims).take_front(2)); // Let warpsPerCTAMma = {2, 2}, then // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB @@ -952,7 +952,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, } } - ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims); + ctaLayout *= LinearLayout({{S("warp"), warps}}, kMinorDims); return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp index 03b7c56b7e6b..51cb7a04faf4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp @@ -76,7 +76,7 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, return base; } -bool isKMajor(llvm::ArrayRef order, int opIdx) { +bool isKMinor(llvm::ArrayRef order, int opIdx) { auto rank = order.size(); int kdim = opIdx == 0 ? rank - 1 : rank - 2; return order[0] == kdim; @@ -106,9 +106,9 @@ bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout, const auto swizzleSlowDimSize = sharedLayout.getMaxPhase() * sharedLayout.getPerPhase(); const auto swizzlePatternSizeK = - isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; + isKMinor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; const auto swizzlePatternSizeNonK = - !isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; + !isKMinor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; const auto blockSizeK = mfmaInstrK * reps[reps.size() - 1]; const auto blockSizeNonK = mfmaInstrNonK * warpsPerBlockNonK; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h index 0db193e1c102..0d57ee2a4f56 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h @@ -37,7 +37,7 @@ Value computeOffset(ConversionPatternRewriter &rewriter, Location loc, Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, const SharedMemoryObject &smemObj); -bool isKMajor(llvm::ArrayRef order, int opIdx); +bool isKMinor(llvm::ArrayRef order, int opIdx); using computeTensorElemMappingInBlockT = std::function>( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 9043090802bf..7d32b7253759 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -279,7 +279,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, SmallVector offsets; Value smemBase; bool isFastPath = - !AMD::isKMajor(order, opIdx) && !hasSwizzleEnabled(sharedLayout); + !AMD::isKMinor(order, opIdx) && !hasSwizzleEnabled(sharedLayout); if (isFastPath) { // fast path handles tensors that are not k-major and have swizzling // disabled, in which case offsets computation can be simplified diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 7b7ca7d1e238..6a79568c4175 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -417,13 +417,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, // getValuesFromDotOperandLayoutStruct as both a and b are K-major assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(), aShapePerCTA.size(), - /*kMajor=*/true)); + /*kMinor=*/true)); auto ha = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(), bShapePerCTA.size(), - /*kMajor=*/true)); + /*kMinor=*/true)); auto hb = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy);