1212#include " mlir/Dialect/Math/Transforms/Passes.h"
1313#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
1414#include " mlir/Dialect/SCF/IR/SCF.h"
15+ #include " mlir/Dialect/Vector/IR/VectorOps.h"
1516#include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1617#include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1718#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -183,6 +184,321 @@ struct SetMulAddFMF final : OpRewritePattern<vector::MultiDimReductionOp> {
183184 }
184185};
185186
187+ // Rewrites vector.contracts into a chain of math.fma ops when possible.
188+ // Starting from the innermost position of the reduction dimension,
189+ // the lowering emits a single nested FMA chain as follows:
190+ // fma(a0, b0, fma(a1, b1, fma(a2, b2, fma(a3, b3, acc))))
191+ // where ai and bi are the elements extracted from lhs and rhs vectors
192+ // respectively along the reduction dimension.
193+ //
194+ // Example:
195+ // ```mlir
196+ // #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
197+ // #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
198+ // vector.contract
199+ // {
200+ // indexing_maps = [#map, #map, #map1],
201+ // iterator_types = ["parallel", "parallel", "reduction"],
202+ // kind = #vector.kind<add>
203+ // }
204+ // %arg0, %arg1, %cst : vector<2x1x8xf16>, vector<2x1x8xf16> into
205+ // vector<2x1xf16>
206+ // ```
207+ //
208+ // ==>
209+ // <Extract lhs/rhs along reduction dim> then:
210+ // ```mlir
211+ // %34 = math.fma %32, %33, %cst : vector<2xf16>
212+ // %37 = math.fma %35, %36, %34 : vector<2xf16>
213+ // %40 = math.fma %38, %39, %37 : vector<2xf16>
214+ // %43 = math.fma %41, %42, %40 : vector<2xf16>
215+ // %46 = math.fma %44, %45, %43 : vector<2xf16>
216+ // %49 = math.fma %47, %48, %46 : vector<2xf16>
217+ // %52 = math.fma %50, %51, %49 : vector<2xf16>
218+ // %55 = math.fma %53, %54, %52 : vector<2xf16>
219+ // ```
220+ //
221+ // Previously, contracts of the same form lowered to elementwise multiplies
222+ // followed by a vector.reduce. This lowering elides the need to reduce the
223+ // result of the elementwise operations separately and instead accumulates
224+ // the result directly via FMAs, offering more profitable instruction-level
225+ // scheduling on GPUs.
226+ struct ContractToChainFMA final : OpRewritePattern<vector::ContractionOp> {
227+ using Base::Base;
228+
229+ LogicalResult matchAndRewrite (vector::ContractionOp op,
230+ PatternRewriter &rewriter) const override {
231+ // TODO: Add a rewrite to support relevant contractions nested in
232+ // vector.mask.
233+ if (op.isMasked () || op.getKind () != vector::CombiningKind::ADD) {
234+ return failure ();
235+ }
236+
237+ VectorType lhsVecType = op.getLhsType ();
238+ VectorType rhsVecType = op.getRhsType ();
239+ if (lhsVecType.isScalable () || rhsVecType.isScalable ()) {
240+ return failure ();
241+ }
242+
243+ auto resultVecType = dyn_cast<VectorType>(op.getResultType ());
244+ if (!resultVecType || resultVecType.isScalable ()) {
245+ return failure ();
246+ }
247+
248+ auto maybeAccVecType = dyn_cast<VectorType>(op.getAccType ());
249+ if (maybeAccVecType && maybeAccVecType.isScalable ()) {
250+ return failure ();
251+ }
252+
253+ if (!isa<FloatType>(lhsVecType.getElementType ())) {
254+ return failure ();
255+ }
256+
257+ SmallVector<int64_t > redDims, parDims;
258+ getReductionAndParallelLoopDims (op.getIteratorTypes (), redDims, parDims);
259+ if (redDims.empty ()) {
260+ return failure ();
261+ }
262+
263+ auto elemType = getElementTypeOrSelf (op.getAccType ());
264+
265+ Location loc = op.getLoc ();
266+ Value lhs = op.getLhs ();
267+ Value rhs = op.getRhs ();
268+
269+ if (lhsVecType.getElementType () != elemType) {
270+ Type promotedType = lhsVecType.clone (elemType);
271+ lhs = arith::ExtFOp::create (rewriter, loc, promotedType, lhs);
272+ lhsVecType = cast<VectorType>(lhs.getType ());
273+ }
274+
275+ if (rhsVecType.getElementType () != elemType) {
276+ Type promotedType = rhsVecType.clone (elemType);
277+ rhs = arith::ExtFOp::create (rewriter, loc, promotedType, rhs);
278+ rhsVecType = cast<VectorType>(rhs.getType ());
279+ }
280+
281+ // New indices: [reduction..., parallel...].
282+ auto indices = llvm::to_vector (llvm::concat<int64_t >(redDims, parDims));
283+
284+ ArrayRef<int64_t > lhsShape = lhsVecType.getShape ();
285+ ArrayRef<int64_t > rhsShape = rhsVecType.getShape ();
286+ SmallVector<AffineMap, 4 > maps = op.getIndexingMapsArray ();
287+ AffineMap lhsMap = maps[0 ];
288+ AffineMap rhsMap = maps[1 ];
289+ AffineMap accMap = maps[2 ];
290+
291+ // Broadcast operands for missing parallel dimensions.
292+ unsigned numParallelDims = accMap.getNumResults ();
293+
294+ SmallVector<int64_t > lhsTranspose, rhsTranspose;
295+ lhs = broadcastMissingDims (
296+ rewriter, loc, lhsMap, accMap, op.getIteratorTypes (), numParallelDims,
297+ resultVecType, lhs, lhsShape, elemType, lhsTranspose);
298+ rhs = broadcastMissingDims (
299+ rewriter, loc, rhsMap, accMap, op.getIteratorTypes (), numParallelDims,
300+ resultVecType, rhs, rhsShape, elemType, rhsTranspose);
301+
302+ // Apply transposes to get [reduction..., parallel...] layout.
303+ lhs = vector::TransposeOp::create (rewriter, loc, lhs, lhsTranspose);
304+ rhs = vector::TransposeOp::create (rewriter, loc, rhs, rhsTranspose);
305+
306+ SmallVector<int64_t > accPerm;
307+ if (maybeAccVecType) {
308+ accPerm = getPermutationFromIndexingMap (maps[2 ], parDims);
309+ }
310+
311+ const size_t numRed = redDims.size ();
312+ auto lhsTransposedVecType = cast<VectorType>(lhs.getType ());
313+ int64_t lhsRedSize = productOfDims (lhsTransposedVecType, 0 , numRed);
314+ int64_t lhsParSize = productOfDims (lhsTransposedVecType, numRed,
315+ lhsTransposedVecType.getRank ());
316+
317+ // Shape-cast operands to 2D {reduction_size, parallel_size}.
318+ int64_t redSize = lhsRedSize;
319+ int64_t parSize = lhsParSize;
320+ auto flattened2DType = VectorType::get ({redSize, parSize}, elemType);
321+ Value lhs2D =
322+ vector::ShapeCastOp::create (rewriter, loc, flattened2DType, lhs);
323+ Value rhs2D =
324+ vector::ShapeCastOp::create (rewriter, loc, flattened2DType, rhs);
325+
326+ Value flattenedAcc;
327+ auto flatAccVecType = VectorType::get ({parSize}, elemType);
328+ VectorType preFlattenVecType = maybeAccVecType;
329+
330+ if (maybeAccVecType) {
331+ Value acc = op.getAcc ();
332+
333+ if (!isIdentityPermutation (accPerm)) {
334+ acc = vector::TransposeOp::create (rewriter, loc, acc, accPerm);
335+ preFlattenVecType = cast<VectorType>(acc.getType ());
336+ }
337+
338+ flattenedAcc =
339+ vector::ShapeCastOp::create (rewriter, loc, flatAccVecType, acc);
340+ } else {
341+ flattenedAcc = vector::BroadcastOp::create (rewriter, loc, flatAccVecType,
342+ op.getAcc ());
343+ }
344+
345+ Value resultFlat =
346+ buildFMAChain (rewriter, loc, lhs2D, rhs2D, flattenedAcc, redSize);
347+
348+ // Restore result to original form.
349+ Value result;
350+ if (maybeAccVecType) {
351+ Value reshaped = vector::ShapeCastOp::create (
352+ rewriter, loc, preFlattenVecType, resultFlat);
353+
354+ if (!isIdentityPermutation (accPerm)) {
355+ result = vector::TransposeOp::create (rewriter, loc, maybeAccVecType,
356+ reshaped, invert (accPerm));
357+ } else {
358+ result = reshaped;
359+ }
360+
361+ } else {
362+ result = vector::ExtractOp::create (rewriter, loc, resultFlat, 0 );
363+ }
364+
365+ rewriter.replaceOp (op, result);
366+ return success ();
367+ }
368+
369+ private:
370+ static Value broadcastMissingDims (
371+ PatternRewriter &rewriter, Location loc, AffineMap operandMap,
372+ AffineMap accMap, ArrayAttr iteratorTypes, unsigned numParallelDims,
373+ VectorType resultType, Value operand, ArrayRef<int64_t > operandShape,
374+ Type elemType, SmallVectorImpl<int64_t > &transpose) {
375+ SmallVector<int64_t > reductionDims =
376+ getReductionIndex (operandMap, iteratorTypes);
377+
378+ unsigned numDimToBroadcast =
379+ numParallelDims - (operandMap.getNumResults () - reductionDims.size ());
380+
381+ SmallVector<int64_t > broadcastDims;
382+
383+ for (int64_t dim : reductionDims) {
384+ transpose.push_back (numDimToBroadcast + dim);
385+ }
386+
387+ for (unsigned i = 0 ; i < numParallelDims; ++i) {
388+ unsigned iterDim = accMap.getDimPosition (i);
389+
390+ std::optional<unsigned > opDim = getDimPosition (operandMap, iterDim);
391+ if (opDim) {
392+ transpose.push_back (numDimToBroadcast + *opDim);
393+ } else {
394+ broadcastDims.push_back (resultType.getDimSize (i));
395+ transpose.push_back (broadcastDims.size () - 1 );
396+ }
397+ }
398+
399+ Value result = operand;
400+ if (!broadcastDims.empty ()) {
401+ llvm::append_range (broadcastDims, operandShape);
402+ auto expandedType = VectorType::get (broadcastDims, elemType);
403+ result = vector::BroadcastOp::create (rewriter, loc, expandedType, result);
404+ }
405+
406+ return result;
407+ }
408+
409+ static std::optional<unsigned > getDimPosition (AffineMap map, unsigned dim) {
410+ for (unsigned i = 0 , e = map.getNumResults (); i < e; i++) {
411+ if (map.getDimPosition (i) == dim)
412+ return i;
413+ }
414+ return std::nullopt ;
415+ }
416+
417+ static SmallVector<int64_t > getReductionIndex (AffineMap map,
418+ ArrayAttr iteratorTypes) {
419+ SmallVector<int64_t > dimsIdx;
420+ for (unsigned i = 0 , e = map.getNumResults (); i < e; i++) {
421+ if (vector::isReductionIterator (iteratorTypes[map.getDimPosition (i)]))
422+ dimsIdx.push_back (i);
423+ }
424+ return dimsIdx;
425+ }
426+
427+ static SmallVector<int64_t > invert (ArrayRef<int64_t > perm) {
428+ SmallVector<int64_t > inv (perm.size ());
429+ for (auto [i, p] : llvm::enumerate (perm)) {
430+ inv[p] = i;
431+ }
432+ return inv;
433+ }
434+
435+ static void getReductionAndParallelLoopDims (ArrayAttr iters,
436+ SmallVectorImpl<int64_t > &red,
437+ SmallVectorImpl<int64_t > &par) {
438+ for (auto [idx, attr] : llvm::enumerate (iters)) {
439+ if (vector::isReductionIterator (attr)) {
440+ red.push_back (idx);
441+ } else {
442+ par.push_back (idx);
443+ }
444+ }
445+ }
446+
447+ // / Constructs a permutation for vector.transpose from an affine map and a
448+ // / reordered list of dimension.
449+ // /
450+ // / Example:
451+ // / map: (d0, d1, d2) -> (d0, d2, d1)
452+ // / iterator_types = ["parallel","parallel","reduction"]
453+ // ==> new dim order: [2, 0, 1]
454+ // /
455+ // / Step 1: Build dim-to-result mapping from the map.
456+ // / dimToRes = [0, 2, 1] i.e {0: 0, 1: 2, 2: 1}
457+ // /
458+ // / Step 2: Walk new dimension order in order to build permutation.
459+ // / indices[0]=2 -> dimToRes[2]=1
460+ // / indices[1]=0 -> dimToRes[0]=0
461+ // / indices[2]=1 -> dimToRes[1]=2
462+ // /
463+ // / Result: perm = [1, 0, 2]
464+ static SmallVector<int64_t >
465+ getPermutationFromIndexingMap (AffineMap map, ArrayRef<int64_t > indices) {
466+ SmallVector<int64_t > dimToRes (map.getNumDims ());
467+ for (int res = 0 , e = map.getNumResults (); res != e; ++res) {
468+ dimToRes[map.getDimPosition (res)] = res;
469+ }
470+
471+ return to_vector (
472+ llvm::map_range (indices, [&](int64_t i) { return dimToRes[i]; }));
473+ }
474+
475+ static int64_t productOfDims (VectorType vt, unsigned lo, unsigned hi) {
476+ int64_t p = 1 ;
477+ for (unsigned i = lo; i < hi; ++i) {
478+ p *= vt.getDimSize (i);
479+ }
480+ return p;
481+ }
482+
483+ static bool isIdentityPermutation (ArrayRef<int64_t > perm) {
484+ return llvm::all_of (llvm::enumerate (perm),
485+ [](auto p) { return p.value () == p.index (); });
486+ }
487+
488+ static Value buildFMAChain (PatternRewriter &rewriter, Location loc,
489+ Value lhs2D, Value rhs2D, Value accFlat,
490+ int64_t K) {
491+ Value current = accFlat;
492+
493+ for (int64_t k = K - 1 ; k >= 0 ; --k) {
494+ Value a = vector::ExtractOp::create (rewriter, loc, lhs2D, k);
495+ Value b = vector::ExtractOp::create (rewriter, loc, rhs2D, k);
496+ current = math::FmaOp::create (rewriter, loc, a, b, current);
497+ }
498+ return current;
499+ }
500+ };
501+
186502struct LLVMGPUVectorLoweringPass final
187503 : impl::LLVMGPUVectorLoweringPassBase<LLVMGPUVectorLoweringPass> {
188504 void getDependentDialects (DialectRegistry ®istry) const override {
@@ -206,6 +522,14 @@ struct LLVMGPUVectorLoweringPass final
206522 }
207523 }
208524
525+ {
526+ RewritePatternSet patterns (ctx);
527+ vector::populateVectorReductionToContractPatterns (patterns);
528+ if (failed (applyPatternsGreedily (funcOp, std::move (patterns)))) {
529+ return signalPassFailure ();
530+ }
531+ }
532+
209533 {
210534 // Lower high level vector operations like contract or multidim reduce ops
211535 // to lower level vector ops.
@@ -222,6 +546,8 @@ struct LLVMGPUVectorLoweringPass final
222546 contractLoweringPatterns, options.vectorContractLowering );
223547 contractLoweringPatterns.add <PromoteContractOperands>(
224548 funcOp->getContext ());
549+ contractLoweringPatterns.add <ContractToChainFMA>(funcOp->getContext (),
550+ PatternBenefit (2 ));
225551 vector::populateVectorGatherLoweringPatterns (contractLoweringPatterns);
226552 vector::populateVectorMaskOpLoweringPatterns (contractLoweringPatterns);
227553 vector::populateVectorShapeCastLoweringPatterns (contractLoweringPatterns);
0 commit comments