@@ -494,17 +494,17 @@ static bool matchExpOpForLUT(math::ExpOp::Adaptor adaptor) {
494494// Rewrite patterns
495495// ===----------------------------------------------------------------------===//
496496
497- // This pattern fold `vector.extract` and `vector.splat ` into
497+ // This pattern fold `vector.extract` and `vector.broadcast ` into
498498// `aievec.broadcast` for AIE2
499499struct FoldVectorExtractAndSplatToAIEBroadcast
500- : OpConversionPattern<vector::SplatOp > {
500+ : OpConversionPattern<vector::BroadcastOp > {
501501 using OpConversionPattern::OpConversionPattern;
502502
503503 LogicalResult
504- matchAndRewrite (vector::SplatOp splatOp , OpAdaptor adaptor,
504+ matchAndRewrite (vector::BroadcastOp bcastOp , OpAdaptor adaptor,
505505 ConversionPatternRewriter &rewriter) const override {
506506
507- auto extOp = adaptor.getInput ().getDefiningOp <vector::ExtractOp>();
507+ auto extOp = adaptor.getSource ().getDefiningOp <vector::ExtractOp>();
508508
509509 if (!extOp)
510510 return failure ();
@@ -513,7 +513,7 @@ struct FoldVectorExtractAndSplatToAIEBroadcast
513513 auto pos = extOp.getStaticPosition ();
514514 int64_t posVal = pos[0 ];
515515 auto srcVecType = cast<VectorType>(src.getType ());
516- auto resultType = cast<VectorType>(splatOp .getResult ().getType ());
516+ auto resultType = cast<VectorType>(bcastOp .getResult ().getType ());
517517 if (srcVecType != resultType) {
518518 if (srcVecType.getNumElements () != 2 * resultType.getNumElements ())
519519 return failure ();
@@ -530,17 +530,17 @@ struct FoldVectorExtractAndSplatToAIEBroadcast
530530 if (unsigned laneSize = getVectorLaneSize (resultType);
531531 laneSize * elWidth == 512 ) {
532532 // Common use case for the broadcast_elem intrinsic
533- rewriter.replaceOpWithNewOp <aievec::BroadcastOp>(splatOp , resultType, src,
533+ rewriter.replaceOpWithNewOp <aievec::BroadcastOp>(bcastOp , resultType, src,
534534 posVal);
535535 } else if (laneSize * elWidth == 256 ) {
536536 // e.g. need v16bf16 due to the subsequent v16accfloat operation
537537 VectorType aievecBcastType =
538538 createVectorType (512 / elWidth, resultType.getElementType ());
539539 auto concatOp = rewriter.create <aievec::ConcatOp>(
540- splatOp .getLoc (), aievecBcastType, SmallVector<Value>({src, src}));
540+ bcastOp .getLoc (), aievecBcastType, SmallVector<Value>({src, src}));
541541 auto aieBcastOp = rewriter.create <aievec::BroadcastOp>(
542- splatOp .getLoc (), aievecBcastType, concatOp.getResult (), posVal);
543- rewriter.replaceOpWithNewOp <aievec::ExtOp>(splatOp , resultType,
542+ bcastOp .getLoc (), aievecBcastType, concatOp.getResult (), posVal);
543+ rewriter.replaceOpWithNewOp <aievec::ExtOp>(bcastOp , resultType,
544544 aieBcastOp.getResult (), 0 );
545545 } else if (laneSize * elWidth == 1024 ) {
546546 // e.g. need v32int32 due to the subsequent v32acc32 operation
@@ -549,12 +549,12 @@ struct FoldVectorExtractAndSplatToAIEBroadcast
549549 auto half = static_cast <int8_t >(posVal / resultType.getNumElements ());
550550 posVal -= half * resultType.getNumElements ();
551551 auto extOp =
552- rewriter.create <aievec::ExtOp>(splatOp .getLoc (), aievecBcastType, src,
552+ rewriter.create <aievec::ExtOp>(bcastOp .getLoc (), aievecBcastType, src,
553553 rewriter.getI8IntegerAttr (half));
554554 auto aieBcastOp = rewriter.create <aievec::BroadcastOp>(
555- splatOp .getLoc (), aievecBcastType, extOp.getResult (), posVal);
555+ bcastOp .getLoc (), aievecBcastType, extOp.getResult (), posVal);
556556 rewriter.replaceOpWithNewOp <aievec::ConcatOp>(
557- splatOp , resultType,
557+ bcastOp , resultType,
558558 SmallVector<Value>({aieBcastOp.getResult (), aieBcastOp.getResult ()}));
559559 } else {
560560 return failure ();
@@ -564,57 +564,57 @@ struct FoldVectorExtractAndSplatToAIEBroadcast
564564 }
565565};
566566
567- struct ConvertSplatToAIEBroadcast : OpConversionPattern<vector::SplatOp > {
567+ struct ConvertSplatToAIEBroadcast : OpConversionPattern<vector::BroadcastOp > {
568568 using OpConversionPattern::OpConversionPattern;
569569
570570 LogicalResult
571- matchAndRewrite (vector::SplatOp splatOp , OpAdaptor adaptor,
571+ matchAndRewrite (vector::BroadcastOp bcastOp , OpAdaptor adaptor,
572572 ConversionPatternRewriter &rewriter) const override {
573573
574- if (adaptor.getInput ().getDefiningOp <vector::ExtractOp>())
574+ if (adaptor.getSource ().getDefiningOp <vector::ExtractOp>())
575575 return failure ();
576576
577- auto resultType = cast<VectorType>(splatOp .getResult ().getType ());
577+ auto resultType = cast<VectorType>(bcastOp .getResult ().getType ());
578578 auto flatResultType = getFlattenedVectorType (resultType);
579579 Type scalarType = resultType.getElementType ();
580580 unsigned elWidth = scalarType.getIntOrFloatBitWidth ();
581581 unsigned laneSize = getVectorLaneSize (resultType);
582- auto src = splatOp. getInput ();
582+ auto src = bcastOp. getSource ();
583583
584584 if (laneSize * elWidth == 512 ) {
585585 Value newOp = rewriter.create <aievec::BroadcastScalarOp>(
586- splatOp .getLoc (), flatResultType, src);
586+ bcastOp .getLoc (), flatResultType, src);
587587 if (resultType != flatResultType)
588- newOp = rewriter.create <vector::ShapeCastOp>(splatOp .getLoc (),
588+ newOp = rewriter.create <vector::ShapeCastOp>(bcastOp .getLoc (),
589589 resultType, newOp);
590- rewriter.replaceOp (splatOp , newOp);
590+ rewriter.replaceOp (bcastOp , newOp);
591591 return success ();
592592 }
593593
594594 if (laneSize * elWidth == 256 ) {
595595 VectorType vecType = createVectorType (512 / elWidth, scalarType);
596596 auto aieBcastOp = rewriter.create <aievec::BroadcastScalarOp>(
597- splatOp .getLoc (), vecType, src);
597+ bcastOp .getLoc (), vecType, src);
598598 Value newOp = rewriter.create <aievec::ExtOp>(
599- splatOp .getLoc (), flatResultType, aieBcastOp.getResult (), 0 );
599+ bcastOp .getLoc (), flatResultType, aieBcastOp.getResult (), 0 );
600600 if (resultType != flatResultType)
601- newOp = rewriter.create <vector::ShapeCastOp>(splatOp .getLoc (),
601+ newOp = rewriter.create <vector::ShapeCastOp>(bcastOp .getLoc (),
602602 resultType, newOp);
603- rewriter.replaceOp (splatOp , newOp);
603+ rewriter.replaceOp (bcastOp , newOp);
604604 return success ();
605605 }
606606
607607 if (laneSize * elWidth == 1024 ) {
608608 VectorType vecType = createVectorType (512 / elWidth, scalarType);
609609 auto aieBcastOp = rewriter.create <aievec::BroadcastScalarOp>(
610- splatOp .getLoc (), vecType, src);
610+ bcastOp .getLoc (), vecType, src);
611611 Value newOp = rewriter.create <aievec::ConcatOp>(
612- splatOp .getLoc (), flatResultType,
612+ bcastOp .getLoc (), flatResultType,
613613 SmallVector<Value>({aieBcastOp.getResult (), aieBcastOp.getResult ()}));
614614 if (resultType != flatResultType)
615- newOp = rewriter.create <vector::ShapeCastOp>(splatOp .getLoc (),
615+ newOp = rewriter.create <vector::ShapeCastOp>(bcastOp .getLoc (),
616616 resultType, newOp);
617- rewriter.replaceOp (splatOp , newOp);
617+ rewriter.replaceOp (bcastOp , newOp);
618618 return success ();
619619 }
620620
@@ -961,19 +961,19 @@ struct FoldSplatToFMAOp : OpConversionPattern<aievec::aie1::FMAOp> {
961961 dyn_cast<aievec::ConcatOp>(adaptor.getLhs ().getDefiningOp ());
962962 if (!concatOp)
963963 return failure ();
964- vector::SplatOp splatOp = nullptr ;
964+ vector::BroadcastOp bcastOp = nullptr ;
965965 auto *concatDefOp = concatOp.getSources ()[0 ].getDefiningOp ();
966966 if (concatDefOp)
967- splatOp = dyn_cast<vector::SplatOp >(concatDefOp);
967+ bcastOp = dyn_cast<vector::BroadcastOp >(concatDefOp);
968968 Value lhs = adaptor.getRhs ();
969- if (!splatOp ) {
970- splatOp = dyn_cast<vector::SplatOp >(adaptor.getRhs ().getDefiningOp ());
971- if (!splatOp )
969+ if (!bcastOp ) {
970+ bcastOp = dyn_cast<vector::BroadcastOp >(adaptor.getRhs ().getDefiningOp ());
971+ if (!bcastOp )
972972 return failure ();
973973 lhs = concatOp.getSources ()[0 ];
974974 }
975975 auto extOp =
976- dyn_cast<vector::ExtractOp>(splatOp. getInput ().getDefiningOp ());
976+ dyn_cast<vector::ExtractOp>(bcastOp. getSource ().getDefiningOp ());
977977 if (!extOp)
978978 return failure ();
979979
@@ -3540,18 +3540,18 @@ static void configureAIEVecV1Legalizations(ConversionTarget &target,
35403540 if (!concatOp)
35413541 return true ;
35423542
3543- vector::SplatOp srcSplat = nullptr ;
3543+ vector::BroadcastOp srcBcast = nullptr ;
35443544 if (auto *lhsOp = concatOp.getSources ()[0 ].getDefiningOp ())
3545- srcSplat = dyn_cast<vector::SplatOp >(lhsOp);
3546- if (!srcSplat ) {
3545+ srcBcast = dyn_cast<vector::BroadcastOp >(lhsOp);
3546+ if (!srcBcast ) {
35473547 auto *rhsOp = op.getRhs ().getDefiningOp ();
35483548 if (!rhsOp)
35493549 return true ;
3550- srcSplat = dyn_cast<vector::SplatOp >(rhsOp);
3550+ srcBcast = dyn_cast<vector::BroadcastOp >(rhsOp);
35513551 }
35523552
3553- if (srcSplat )
3554- if (auto *srcOp = srcSplat. getInput ().getDefiningOp ())
3553+ if (srcBcast )
3554+ if (auto *srcOp = srcBcast. getSource ().getDefiningOp ())
35553555 return !isa<vector::ExtractOp>(srcOp);
35563556
35573557 return true ;
0 commit comments