Skip to content

Commit b98c1b9

Browse files
authored
[LLVMGPU][Codegen] Emit packed chain FMA from select multi_reductions and contracts (#21855)
This patch teaches the vector lowering pipeline to: 1. Rewrite `vector.multi_reduction<add>` whose input is `arith.mulf` into a `vector.contract` via `vector::populateVectorReductionToContractPatterns`. 2. Lower a restricted set of `vector.contract` ops into packed FMA chains. Previously, lowering a `vector.multi_reduction` of the same form produced elementwise packed muls per K-slice and then reduced them with a left-associated, serial chain of `v_add_f{16, 32}`: `((mul(a0, b0) + mul(a1, b1)) + …) + acc`. The new lowering emits a single nested FMA chain and folds the accumulation into the `math.fma` c-operand: `fma(a0, b0, fma(a1, b1, fma(a2, b2, fma(a3, b3, acc))))`. To do this, we first permute the reduction and parallel dimensions of the `LHS` and `RHS` to the order of `[reduction, ..., parallel, ...]`. The `LHS` and `RHS` are then collapsed to a 2D shape of `{Π reduction dimensions, Π parallel dimensions}`. Then we form the FMA chain by iterating backwards, seeded by the accumulator. Not all forms of `vector.contract` are suitable for the current approach, for example when an operand drops a parallel iterator, as in matmul-like contracts. We require both sides to share the same 2D tuple. Unsupported cases fall back to the existing lowering. Fixes: #21483 (variant of original issue; for [issue #21513](#21513)). --------- Signed-off-by: Eric Feng <[email protected]>
1 parent df3d076 commit b98c1b9

File tree

4 files changed

+442
-4
lines changed

4 files changed

+442
-4
lines changed

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Math/Transforms/Passes.h"
1313
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
1414
#include "mlir/Dialect/SCF/IR/SCF.h"
15+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1516
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1617
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1718
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -183,6 +184,321 @@ struct SetMulAddFMF final : OpRewritePattern<vector::MultiDimReductionOp> {
183184
}
184185
};
185186

187+
// Rewrites vector.contracts into a chain of math.fma ops when possible.
188+
// Starting from the innermost position of the reduction dimension,
189+
// the lowering emits a single nested FMA chain as follows:
190+
// fma(a0 ,b0, fma(a1, b1, fma(a2, b2, fma(a3, b3, acc))))
191+
// where ai and bi are the elements extracted from lhs and rhs vectors
192+
// respectively along the reduction dimension.
193+
//
194+
// Example:
195+
// ```mlir
196+
// #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
197+
// #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
198+
// vector.contract
199+
//{
200+
// indexing_maps = [#map, #map, #map1],
201+
// iterator_types = ["parallel", "parallel", "reduction"],
202+
// kind = #vector.kind<add>
203+
// }
204+
// %arg0, %arg1, %cst : vector<2x1x8xf16>, vector<2x1x8xf16> into
205+
// vector<2x1xf16>
206+
// ```
207+
//
208+
// ==>
209+
// <Extract lhs/rhs along reduction dim> then:
210+
// ```mlir
211+
// %34 = math.fma %32, %33, %cst : vector<2xf16>
212+
// %37 = math.fma %35, %36, %34 : vector<2xf16>
213+
// %40 = math.fma %38, %39, %37 : vector<2xf16>
214+
// %43 = math.fma %41, %42, %40 : vector<2xf16>
215+
// %46 = math.fma %44, %45, %43 : vector<2xf16>
216+
// %49 = math.fma %47, %48, %46 : vector<2xf16>
217+
// %52 = math.fma %50, %51, %49 : vector<2xf16>
218+
// %55 = math.fma %53, %54, %52 : vector<2xf16>
219+
// ```
220+
//
221+
// Previously, contracts of the same form lowered to elementwise multiplies
222+
// followed by a vector.reduce. This lowering elides the need to reduce the
223+
// result of the elementwise operations separately and instead accumulates
224+
// the result directly via FMAs, offering more profitable instruction level
225+
// scheduling on GPUs.
226+
struct ContractToChainFMA final : OpRewritePattern<vector::ContractionOp> {
227+
using Base::Base;
228+
229+
/// Rewrites an unmasked, fixed-size, floating-point `vector.contract` with
/// ADD combining kind into a chain of `math.fma` ops.
///
/// Steps:
///   1. Bail out (before creating any IR) on unsupported contractions so the
///      default contract lowering can handle them.
///   2. Promote lhs/rhs to the accumulator element type with `arith.extf`.
///   3. Broadcast/transpose lhs and rhs to a `[reduction..., parallel...]`
///      layout and shape-cast them to 2-D `{reduction, parallel}`.
///   4. Flatten the accumulator to 1-D `{parallel}` (broadcasting a scalar
///      accumulator), build the FMA chain, and restore the original result
///      shape.
///
/// Returns failure() without touching the IR when the pattern does not apply.
LogicalResult matchAndRewrite(vector::ContractionOp op,
                              PatternRewriter &rewriter) const override {
  // TODO: Add a rewrite to support relevant contractions nested in
  // vector.mask.
  if (op.isMasked() || op.getKind() != vector::CombiningKind::ADD) {
    return failure();
  }

  VectorType lhsVecType = op.getLhsType();
  VectorType rhsVecType = op.getRhsType();
  if (lhsVecType.isScalable() || rhsVecType.isScalable()) {
    return failure();
  }

  auto resultVecType = dyn_cast<VectorType>(op.getResultType());
  if (!resultVecType || resultVecType.isScalable()) {
    return failure();
  }

  auto maybeAccVecType = dyn_cast<VectorType>(op.getAccType());
  if (maybeAccVecType && maybeAccVecType.isScalable()) {
    return failure();
  }

  // math.fma only operates on floats, and the operands are promoted to the
  // accumulator element type with arith.extf, which requires the destination
  // to be strictly wider than the source. Reject everything else up front so
  // that no invalid IR is created and the default lowering can take over.
  Type lhsElemType = lhsVecType.getElementType();
  Type rhsElemType = rhsVecType.getElementType();
  Type elemType = getElementTypeOrSelf(op.getAccType());
  if (!isa<FloatType>(lhsElemType) || !isa<FloatType>(rhsElemType) ||
      !isa<FloatType>(elemType)) {
    return failure();
  }
  auto isPromotable = [&](Type from) {
    return from == elemType ||
           from.getIntOrFloatBitWidth() < elemType.getIntOrFloatBitWidth();
  };
  if (!isPromotable(lhsElemType) || !isPromotable(rhsElemType)) {
    return failure();
  }

  SmallVector<int64_t> redDims, parDims;
  getReductionAndParallelLoopDims(op.getIteratorTypes(), redDims, parDims);
  if (redDims.empty()) {
    return failure();
  }

  Location loc = op.getLoc();
  Value lhs = op.getLhs();
  Value rhs = op.getRhs();

  // Promote narrower operands to the accumulator element type.
  if (lhsElemType != elemType) {
    Type promotedType = lhsVecType.clone(elemType);
    lhs = arith::ExtFOp::create(rewriter, loc, promotedType, lhs);
    lhsVecType = cast<VectorType>(lhs.getType());
  }

  if (rhsElemType != elemType) {
    Type promotedType = rhsVecType.clone(elemType);
    rhs = arith::ExtFOp::create(rewriter, loc, promotedType, rhs);
    rhsVecType = cast<VectorType>(rhs.getType());
  }

  ArrayRef<int64_t> lhsShape = lhsVecType.getShape();
  ArrayRef<int64_t> rhsShape = rhsVecType.getShape();
  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
  AffineMap lhsMap = maps[0];
  AffineMap rhsMap = maps[1];
  AffineMap accMap = maps[2];

  // Broadcast operands for missing parallel dimensions.
  unsigned numParallelDims = accMap.getNumResults();

  SmallVector<int64_t> lhsTranspose, rhsTranspose;
  lhs = broadcastMissingDims(
      rewriter, loc, lhsMap, accMap, op.getIteratorTypes(), numParallelDims,
      resultVecType, lhs, lhsShape, elemType, lhsTranspose);
  rhs = broadcastMissingDims(
      rewriter, loc, rhsMap, accMap, op.getIteratorTypes(), numParallelDims,
      resultVecType, rhs, rhsShape, elemType, rhsTranspose);

  // Apply transposes to get [reduction..., parallel...] layout.
  lhs = vector::TransposeOp::create(rewriter, loc, lhs, lhsTranspose);
  rhs = vector::TransposeOp::create(rewriter, loc, rhs, rhsTranspose);

  SmallVector<int64_t> accPerm;
  if (maybeAccVecType) {
    accPerm = getPermutationFromIndexingMap(accMap, parDims);
  }

  // After the transpose, the leading `numRed` dims are reductions and the
  // trailing dims are parallel.
  const size_t numRed = redDims.size();
  auto lhsTransposedVecType = cast<VectorType>(lhs.getType());
  int64_t redSize = productOfDims(lhsTransposedVecType, 0, numRed);
  int64_t parSize = productOfDims(lhsTransposedVecType, numRed,
                                  lhsTransposedVecType.getRank());

  // Shape-cast operands to 2D {reduction_size, parallel_size}.
  auto flattened2DType = VectorType::get({redSize, parSize}, elemType);
  Value lhs2D =
      vector::ShapeCastOp::create(rewriter, loc, flattened2DType, lhs);
  Value rhs2D =
      vector::ShapeCastOp::create(rewriter, loc, flattened2DType, rhs);

  Value flattenedAcc;
  auto flatAccVecType = VectorType::get({parSize}, elemType);
  // Type of the accumulator right before flattening; used to undo the
  // flattening on the result.
  VectorType preFlattenVecType = maybeAccVecType;

  if (maybeAccVecType) {
    Value acc = op.getAcc();

    if (!isIdentityPermutation(accPerm)) {
      acc = vector::TransposeOp::create(rewriter, loc, acc, accPerm);
      preFlattenVecType = cast<VectorType>(acc.getType());
    }

    flattenedAcc =
        vector::ShapeCastOp::create(rewriter, loc, flatAccVecType, acc);
  } else {
    // Scalar accumulator: splat it into the 1-D parallel vector.
    flattenedAcc = vector::BroadcastOp::create(rewriter, loc, flatAccVecType,
                                               op.getAcc());
  }

  Value resultFlat =
      buildFMAChain(rewriter, loc, lhs2D, rhs2D, flattenedAcc, redSize);

  // Restore result to original form.
  Value result;
  if (maybeAccVecType) {
    Value reshaped = vector::ShapeCastOp::create(
        rewriter, loc, preFlattenVecType, resultFlat);

    if (!isIdentityPermutation(accPerm)) {
      result = vector::TransposeOp::create(rewriter, loc, maybeAccVecType,
                                           reshaped, invert(accPerm));
    } else {
      result = reshaped;
    }

  } else {
    result = vector::ExtractOp::create(rewriter, loc, resultFlat, 0);
  }

  rewriter.replaceOp(op, result);
  return success();
}
368+
369+
private:
370+
/// Broadcasts `operand` along the parallel dimensions of the accumulator that
/// its indexing map does not cover, and fills `transpose` with the
/// permutation that brings the (possibly broadcast) operand into a
/// `[reduction..., parallel...]` layout.
///
/// Both groups are ordered by iteration-space dimension. That is the order
/// produced for the accumulator by
/// `getPermutationFromIndexingMap(accMap, parDims)`, and it guarantees that
/// lhs and rhs agree on the order of their reduction dims even when their
/// indexing maps list them differently. Ordering the parallel dims by
/// accumulator-result position, or the reduction dims by operand-result
/// position, mis-pairs elements as soon as one of the maps is permuted.
///
/// Missing parallel dims are prepended by `vector.broadcast`, so original
/// operand dim `j` ends up at position `numDimToBroadcast + j`.
///
/// \param operandMap      indexing map of the operand being prepared.
/// \param accMap          indexing map of the accumulator/result.
/// \param iteratorTypes   iterator types of the contraction.
/// \param numParallelDims number of results of `accMap`.
/// \param resultType      contraction result type; supplies the sizes of the
///                        broadcast dims.
/// \param operandShape    shape of `operand`.
/// \param elemType        element type of `operand`.
/// \param transpose       [out] permutation to pass to `vector.transpose`.
/// \returns `operand`, broadcast if any parallel dim was missing.
static Value broadcastMissingDims(
    PatternRewriter &rewriter, Location loc, AffineMap operandMap,
    AffineMap accMap, ArrayAttr iteratorTypes, unsigned numParallelDims,
    VectorType resultType, Value operand, ArrayRef<int64_t> operandShape,
    Type elemType, SmallVectorImpl<int64_t> &transpose) {
  SmallVector<int64_t> reductionDims =
      getReductionIndex(operandMap, iteratorTypes);

  // Parallel dims of the accumulator that the operand does not carry.
  unsigned numDimToBroadcast =
      numParallelDims - (operandMap.getNumResults() - reductionDims.size());

  SmallVector<int64_t> broadcastDims;

  // Reduction dims first, in iteration-space order.
  for (auto [iterDim, attr] : llvm::enumerate(iteratorTypes)) {
    if (!vector::isReductionIterator(attr)) {
      continue;
    }
    if (std::optional<unsigned> opDim = getDimPosition(operandMap, iterDim)) {
      transpose.push_back(numDimToBroadcast + *opDim);
    }
  }

  // Then the parallel dims of the accumulator, in iteration-space order.
  for (auto [iterDim, attr] : llvm::enumerate(iteratorTypes)) {
    if (vector::isReductionIterator(attr)) {
      continue;
    }
    std::optional<unsigned> accDim = getDimPosition(accMap, iterDim);
    if (!accDim) {
      continue;
    }

    if (std::optional<unsigned> opDim = getDimPosition(operandMap, iterDim)) {
      transpose.push_back(numDimToBroadcast + *opDim);
    } else {
      // The operand lacks this dim: it becomes a new leading broadcast dim
      // sized like the corresponding result dim.
      broadcastDims.push_back(resultType.getDimSize(*accDim));
      transpose.push_back(broadcastDims.size() - 1);
    }
  }

  Value result = operand;
  if (!broadcastDims.empty()) {
    llvm::append_range(broadcastDims, operandShape);
    auto expandedType = VectorType::get(broadcastDims, elemType);
    result = vector::BroadcastOp::create(rewriter, loc, expandedType, result);
  }

  return result;
}
408+
409+
/// Returns the index of the result of `map` that refers to iteration
/// dimension `dim`, or std::nullopt if no result of `map` does.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  const unsigned numResults = map.getNumResults();
  for (unsigned resultIdx = 0; resultIdx != numResults; ++resultIdx) {
    if (map.getDimPosition(resultIdx) != dim) {
      continue;
    }
    return resultIdx;
  }
  return std::nullopt;
}
416+
417+
/// Collects, in increasing order, the result indices of `map` whose
/// iteration dimension is marked as a reduction in `iteratorTypes`.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> reductionResults;
  const unsigned numResults = map.getNumResults();
  for (unsigned resultIdx = 0; resultIdx != numResults; ++resultIdx) {
    const unsigned iterDim = map.getDimPosition(resultIdx);
    if (!vector::isReductionIterator(iteratorTypes[iterDim])) {
      continue;
    }
    reductionResults.push_back(resultIdx);
  }
  return reductionResults;
}
426+
427+
/// Returns the inverse of the permutation `perm`, i.e. the vector `inv`
/// such that `inv[perm[i]] == i` for every position `i`.
static SmallVector<int64_t> invert(ArrayRef<int64_t> perm) {
  const size_t rank = perm.size();
  SmallVector<int64_t> inverse(rank);
  for (size_t pos = 0; pos != rank; ++pos) {
    inverse[perm[pos]] = pos;
  }
  return inverse;
}
434+
435+
/// Partitions the iteration dimensions described by `iters`: the positions
/// of reduction iterators are appended to `red` and all remaining positions
/// to `par`, each in increasing order.
static void getReductionAndParallelLoopDims(ArrayAttr iters,
                                            SmallVectorImpl<int64_t> &red,
                                            SmallVectorImpl<int64_t> &par) {
  for (auto [position, iterator] : llvm::enumerate(iters)) {
    SmallVectorImpl<int64_t> &bucket =
        vector::isReductionIterator(iterator) ? red : par;
    bucket.push_back(position);
  }
}
446+
447+
/// Constructs a permutation for vector.transpose from an affine map and a
448+
/// reordered list of dimension.
449+
///
450+
/// Example:
451+
/// map: (d0, d1, d2) -> (d0, d2, d1)
452+
/// iterator_types = ["parallel","parallel","reduction"]
453+
// ==> new dim order: [2, 0, 1]
454+
///
455+
/// Step 1: Build dim-to-result mapping from the map.
456+
/// dimToRes = [0, 2, 1] i.e {0: 0, 1: 2, 2: 1}
457+
///
458+
/// Step 2: Walk new dimension order in order to build permutation.
459+
/// indices[0]=2 -> dimToRes[2]=1
460+
/// indices[1]=0 -> dimToRes[0]=0
461+
/// indices[2]=1 -> dimToRes[1]=2
462+
///
463+
/// Result: perm = [1, 0, 2]
464+
static SmallVector<int64_t>
465+
getPermutationFromIndexingMap(AffineMap map, ArrayRef<int64_t> indices) {
466+
SmallVector<int64_t> dimToRes(map.getNumDims());
467+
for (int res = 0, e = map.getNumResults(); res != e; ++res) {
468+
dimToRes[map.getDimPosition(res)] = res;
469+
}
470+
471+
return to_vector(
472+
llvm::map_range(indices, [&](int64_t i) { return dimToRes[i]; }));
473+
}
474+
475+
/// Returns the product of the sizes of dimensions `[lo, hi)` of `vt`, or 1
/// when the range is empty.
static int64_t productOfDims(VectorType vt, unsigned lo, unsigned hi) {
  int64_t product = 1;
  unsigned dim = lo;
  while (dim < hi) {
    product *= vt.getDimSize(dim);
    ++dim;
  }
  return product;
}
482+
483+
/// Returns true when every entry of `perm` equals its own position; an empty
/// permutation counts as the identity.
static bool isIdentityPermutation(ArrayRef<int64_t> perm) {
  for (auto [position, target] : llvm::enumerate(perm)) {
    if (target != position) {
      return false;
    }
  }
  return true;
}
487+
488+
/// Emits the nested FMA chain
///   fma(a[0], b[0], fma(a[1], b[1], ... fma(a[K-1], b[K-1], acc)))
/// where a[k] / b[k] are extracted from `lhs2D` / `rhs2D` at leading index
/// `k`. The ops are created starting from the last index, seeded by
/// `accFlat`. Returns `accFlat` unchanged when `K` is not positive.
static Value buildFMAChain(PatternRewriter &rewriter, Location loc,
                           Value lhs2D, Value rhs2D, Value accFlat,
                           int64_t K) {
  Value chain = accFlat;
  for (int64_t remaining = K; remaining > 0; --remaining) {
    const int64_t slice = remaining - 1;
    Value lhsSlice = vector::ExtractOp::create(rewriter, loc, lhs2D, slice);
    Value rhsSlice = vector::ExtractOp::create(rewriter, loc, rhs2D, slice);
    chain = math::FmaOp::create(rewriter, loc, lhsSlice, rhsSlice, chain);
  }
  return chain;
}
500+
};
501+
186502
struct LLVMGPUVectorLoweringPass final
187503
: impl::LLVMGPUVectorLoweringPassBase<LLVMGPUVectorLoweringPass> {
188504
void getDependentDialects(DialectRegistry &registry) const override {
@@ -206,6 +522,14 @@ struct LLVMGPUVectorLoweringPass final
206522
}
207523
}
208524

525+
{
526+
RewritePatternSet patterns(ctx);
527+
vector::populateVectorReductionToContractPatterns(patterns);
528+
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
529+
return signalPassFailure();
530+
}
531+
}
532+
209533
{
210534
// Lower high level vector operations like contract or multidim reduce ops
211535
// to lower level vector ops.
@@ -222,6 +546,8 @@ struct LLVMGPUVectorLoweringPass final
222546
contractLoweringPatterns, options.vectorContractLowering);
223547
contractLoweringPatterns.add<PromoteContractOperands>(
224548
funcOp->getContext());
549+
contractLoweringPatterns.add<ContractToChainFMA>(funcOp->getContext(),
550+
PatternBenefit(2));
225551
vector::populateVectorGatherLoweringPatterns(contractLoweringPatterns);
226552
vector::populateVectorMaskOpLoweringPatterns(contractLoweringPatterns);
227553
vector::populateVectorShapeCastLoweringPatterns(contractLoweringPatterns);

compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ hal.executable @dot_dispatch_0 {
8787
// CHECK: llvm.br
8888
// CHECK: llvm.load {{.*}} : !llvm.ptr<1> -> vector<32xf32>
8989
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<1> -> vector<16xf32>
90-
// CHECK-COUNT-32: llvm.intr.fmuladd({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
90+
// CHECK-COUNT-512: llvm.call @__nv_fmaf({{.*}}) : (f32, f32, f32) -> f32
9191
// CHECK: llvm.store {{.*}} : vector<16xf32>, !llvm.ptr<1>
9292

9393
// -----
@@ -151,7 +151,7 @@ hal.executable @dot_dispatch_0 {
151151
// CHECK-LABEL: hal.executable public @dot_dispatch_0
152152
// CHECK: hal.executable.variant public @cuda
153153
// CHECK: llvm.br
154-
// CHECK-COUNT-32: llvm.intr.fmuladd({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
154+
// CHECK-COUNT-512: llvm.call @__nv_fmaf({{.*}}) : (f32, f32, f32) -> f32
155155
// CHECK: llvm.store {{.*}} : vector<16xf32>, !llvm.ptr<1>
156156

157157
// -----

compiler/src/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ hal.executable @dot_dispatch_0 {
8888
// RDNA3: llvm.br
8989
// RDNA3-COUNT-1: llvm.load {{.*}} : !llvm.ptr<7> -> vector<32xf32>
9090
// RDNA3-COUNT-32: llvm.load {{.*}} : !llvm.ptr<7> -> vector<16xf32>
91-
// RDNA3-COUNT-32: llvm.intr.fmuladd({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
91+
// RDNA3-COUNT-32: llvm.intr.fma({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
9292
// RDNA3-COUNT-1: llvm.store {{.*}} : vector<16xf32>, !llvm.ptr<7>
9393
// RDNA3: llvm.br
9494

0 commit comments

Comments
 (0)