Commit 6da6e84

Merge commit '9f939760d2455bb0644698a5b6f3a13aa485abde'
2 parents: d96a80e + 9f93976

File tree

4 files changed: 39 additions, 32 deletions

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 26 additions & 20 deletions
@@ -7,6 +7,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -77,28 +78,33 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
     }
   }

-  SmallVector<unsigned> ret(rank, 1);
-  SmallVector<int64_t> shapePerWarp(rank, 1);
-  shapePerWarp[rank - 1] = 8;
-  shapePerWarp[rank - 2] = 16;
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
-  do {
-    if (ret[0] * ret[1] >= numWarps)
-      break;
-    if (shape[0] / shapePerWarp[0] / ret[0] >=
-        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0]) {
-        ret[0] *= 2;
-      } else
-        ret[1] *= 2;
+  assert(rank == 2);
+  SmallVector<int64_t> shapePerWarp = {16, 8};
+  SmallVector<int64_t> warps = {1, 1};
+  // Compute repM and repN
+  SmallVector<int64_t> reps = {ceil(shape[0], shapePerWarp[0]),
+                               ceil(shape[1], shapePerWarp[1])};
+  // The formula for the number of registers given the reps is
+  // repM * 4 * repK + repN * 2 * repK + regsC
+  // where regsC = repM * repN * 4, which does not depend on the warp shape
+  //
+  // As such, to minimize the register pressure, we need to balance
+  // repM and repN. We then untie towards M, as the lhs tile has 4 elements,
+  // and the rhs tile has just 2.
+  while (product(warps) < numWarps) {
+    if (reps[0] >= reps[1]) {
+      warps[0] *= 2;
+      // Too many warps for this mma (repM == repN == 1).
+      // We allocate the remaining warps to the left (arbitrary choice).
+      if (reps[0] != 1) {
+        reps[0] /= 2;
+      }
     } else {
-      ret[1] *= 2;
+      warps[1] *= 2;
+      reps[1] /= 2;
     }
-  } while (true);
-  return ret;
+  }
+  return {(unsigned)warps[0], (unsigned)warps[1]};
 }

 SmallVector<unsigned, 2>
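
Note: the rewritten warpsPerTileV2 balances repM against repN to keep register
pressure low before splitting warps. A minimal Python sketch of the same
heuristic follows, assuming a rank-2 dot shape and the MMAv2 16x8 instruction
tile; the function name and the example values are illustrative, not Triton API.

    from math import ceil

    def warps_per_tile_v2(shape, num_warps):
        shape_per_warp = (16, 8)  # MMAv2 instrShape [16, 8]
        warps = [1, 1]
        # repM and repN: how many instruction tiles cover the dot shape
        reps = [ceil(shape[0] / shape_per_warp[0]),
                ceil(shape[1] / shape_per_warp[1])]
        # Registers per warp ~ repM*4*repK + repN*2*repK + repM*repN*4,
        # so balance repM and repN, breaking ties towards M (lhs tiles are larger).
        while warps[0] * warps[1] < num_warps:
            if reps[0] >= reps[1]:
                warps[0] *= 2
                # once repM == repN == 1, keep giving the extra warps to M
                if reps[0] != 1:
                    reps[0] //= 2
            else:
                warps[1] *= 2
                reps[1] //= 2
        return warps

    # e.g. an assumed 128x128 dot with 8 warps gives [2, 4] rather than the
    # old M-heavy [4, 2] split
    print(warps_per_tile_v2((128, 128), 8))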

python/test/unit/runtime/test_subproc.py

Lines changed: 7 additions & 6 deletions
@@ -7,6 +7,7 @@
 from triton.compiler import ASTSource

 target = triton.runtime.driver.active.get_current_target()
+start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'


 def compile_fn(attrs):
@@ -27,8 +28,8 @@ def kernel_sub(a, b, o, N: tl.constexpr):

 def test_compile_in_subproc() -> None:
     config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
-    multiprocessing.set_start_method('fork')
-    proc = multiprocessing.Process(target=compile_fn, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_fn, args=(config, ))
     proc.start()
     proc.join()
     assert proc.exitcode == 0
@@ -49,8 +50,8 @@ def kernel_dot(Z):

 def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
     config = AttrsDescriptor.from_hints({0: 16})
-    assert multiprocessing.get_start_method() == 'fork'
-    proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_fn_dot, args=(config, ))
     proc.start()
     proc.join()
     assert proc.exitcode == 0
@@ -92,8 +93,8 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:

     # stage 2.p
     shutil.rmtree(fresh_triton_cache)
-    assert multiprocessing.get_start_method() == 'fork'
-    proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, ))
+    mp_ctx = multiprocessing.get_context(start_method)
+    proc = mp_ctx.Process(target=compile_empty_kernel_with_gc, args=(config, ))

     # stage 3.c
     proc.start()
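
Note: the test changes drop the global multiprocessing.set_start_method('fork')
call and the hard 'fork' assertions in favour of a context built from whichever
start method is actually available. A standalone sketch of that pattern (the
worker function and its argument are made up for illustration):

    import multiprocessing

    # prefer 'fork' when the platform offers it, otherwise fall back to 'spawn'
    start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'

    def work(x):
        print(x * 2)

    if __name__ == '__main__':
        ctx = multiprocessing.get_context(start_method)  # local context, no global state
        proc = ctx.Process(target=work, args=(21, ))
        proc.start()
        proc.join()
        assert proc.exitcode == 0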

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 3 additions & 3 deletions
@@ -73,7 +73,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :

 // -----

-// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
+// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
@@ -93,7 +93,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 :

 // -----

-// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
 // CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>

 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 // -----

 // Verify that we use mmav2 when the k dim is too small for mmav3.
-// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}>
+// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: small_k_size

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp

Lines changed: 3 additions & 3 deletions
@@ -659,15 +659,15 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,

   int kWidth = encoding.getKWidth();
   auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(
-      shapePerCTA, bitwidth, kWidth, encoding.getOpIdx());
+      shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx());

   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
-  auto order = triton::gpu::getOrder(mmaLayout);
+  auto warpOrder = mmaLayout.getWarpOrder();
   Value warp = udiv(thread, i32_val(32));
   Value lane = urem(thread, i32_val(32));

   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warp, warpsPerCTA, order);
+      delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder);
   Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0]));
   int warpsPerTile;
   Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16));
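
Note: besides passing mmaBitwidth to getMMAv2OrV3RepForOperand, this hunk
delinearizes the warp id with the layout's warpOrder instead of the generic
order. A rough Python sketch of such an order-driven delinearization, assuming
order[0] names the fastest-varying dimension (names are illustrative, not
Triton's C++ helpers):

    def delinearize(linear_id, shape, order):
        coords = [0] * len(shape)
        rem = linear_id
        for dim in order:  # order[0] is the fastest-varying dimension
            coords[dim] = rem % shape[dim]
            rem //= shape[dim]
        return coords

    # with warpsPerCTA = [2, 4], warp 5 lands on different coordinates
    # depending on the order used, which is why the warp order matters here
    print(delinearize(5, [2, 4], [1, 0]))  # -> [1, 1]
    print(delinearize(5, [2, 4], [0, 1]))  # -> [1, 2]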
