@@ -772,33 +772,46 @@ Operation *FftOp::getTiledImplementation(OpBuilder &builder, ValueRange outputs,
772
772
// ===----------------------------------------------------------------------===//
773
773
774
774
static LogicalResult verifyScanOp (ScanOp op) {
775
- if (op.getNumInputs () != 2 ) {
776
- return op.emitOpError (" expected two input operands" );
775
+ if (op.getNumInputs () != 1 ) {
776
+ return op.emitOpError (" expected one input operands" );
777
777
}
778
- if (op.getNumOutputs () != 1 ) {
779
- return op.emitOpError (" expected one output operand " );
778
+ if (op.getNumOutputs () != 2 ) {
779
+ return op.emitOpError (" expected two output operands " );
780
780
}
781
781
if (!op.input ().getType ().isa <ShapedType>()) {
782
782
return op.emitOpError (" expected first input element type to be shaped" );
783
783
}
784
- auto identityElementType = op.identity ().getType ();
785
- if (!(identityElementType.isa <FloatType>() ||
786
- identityElementType.isa <IntegerType>())) {
787
- return op.emitOpError (
788
- " expected second input element type to be float or integer" );
789
- }
784
+ auto accumulatorType = op.accumulator ().getType ().cast <ShapedType>();
790
785
auto inputType = op.input ().getType ().cast <ShapedType>();
791
786
auto outputType = op.output ().getType ().cast <ShapedType>();
792
- if (identityElementType != inputType.getElementType ()) {
787
+ ArrayRef<int64_t > inputShapes = inputType.getShape ();
788
+ ArrayRef<int64_t > outputShapes = outputType.getShape ();
789
+ if (accumulatorType.getElementType () != inputType.getElementType ()) {
793
790
return op.emitOpError (
794
- " expected input/identity element types to be identical" );
791
+ " expected input/accumulator element types to be identical" );
792
+ }
793
+ ArrayRef<int64_t > accumulatorShape = accumulatorType.getShape ();
794
+ int64_t accumulatorRank = accumulatorType.getRank ();
795
+ if (accumulatorRank != inputType.getRank () - 1 ) {
796
+ return op.emitOpError (
797
+ " expected accumulator rank to be equal to input rank - 1" );
798
+ }
799
+ SmallVector<int64_t > expectedAccumulatorShape;
800
+ for (int i = 0 ; i < inputType.getRank (); i++) {
801
+ if (i != op.dimension ()) expectedAccumulatorShape.push_back (inputShapes[i]);
802
+ }
803
+ if (llvm::any_of (llvm::zip (expectedAccumulatorShape, accumulatorShape),
804
+ [](std::tuple<int64_t , int64_t > s) {
805
+ return std::get<0 >(s) != ShapedType::kDynamicSize &&
806
+ std::get<1 >(s) != ShapedType::kDynamicSize &&
807
+ std::get<0 >(s) != std::get<1 >(s);
808
+ })) {
809
+ return op.emitOpError (" incompatible input/accumulator shapes" );
795
810
}
796
811
if (inputType.getElementType () != outputType.getElementType ()) {
797
812
return op.emitOpError (
798
813
" expected input/output element types to be identical" );
799
814
}
800
- ArrayRef<int64_t > inputShapes = inputType.getShape ();
801
- ArrayRef<int64_t > outputShapes = outputType.getShape ();
802
815
if (inputShapes.size () != outputShapes.size ()) {
803
816
return op.emitOpError (" expected input/output to have identical ranks" );
804
817
}
@@ -862,14 +875,20 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
862
875
auto cond = b.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
863
876
indices[scanDim], zero);
864
877
bool isInclusive = inclusive ();
878
+ SmallVector<Value> accIndices;
879
+ for (int i = 0 ; i < indices.size (); i++) {
880
+ if (i != scanDim) accIndices.push_back (indices[i]);
881
+ }
882
+
865
883
auto scfIf = b.create <scf::IfOp>(
866
884
loc, TypeRange{}, cond,
867
885
[&](OpBuilder &b, Location loc) {
868
886
if (isInclusive) {
869
887
auto value = b.create <memref::LoadOp>(loc, input (), indices);
870
888
b.create <memref::StoreOp>(loc, value, output (), indices);
871
889
} else {
872
- b.create <memref::StoreOp>(loc, identity (), output (), indices);
890
+ auto value = b.create <memref::LoadOp>(loc, accumulator (), accIndices);
891
+ b.create <memref::StoreOp>(loc, value, output (), indices);
873
892
}
874
893
b.create <scf::YieldOp>(loc);
875
894
},
@@ -902,6 +921,9 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
902
921
b.create <memref::StoreOp>(
903
922
loc, bvm.lookupOrDefault (srcBlock.getTerminator ()->getOperand (0 )),
904
923
output (), indices);
924
+ b.create <memref::StoreOp>(
925
+ loc, bvm.lookupOrDefault (srcBlock.getTerminator ()->getOperand (0 )),
926
+ accumulator (), accIndices);
905
927
b.create <scf::YieldOp>(loc);
906
928
}
907
929
return success ();
@@ -922,25 +944,61 @@ Operation *ScanOp::getTiledImplementation(OpBuilder &builder,
922
944
SmallVector<Value> tiledOperands;
923
945
tiledOperands.emplace_back (
924
946
getSlice (builder, getLoc (), input (), offsets, sizes, strides));
925
- tiledOperands.emplace_back (identity ());
926
947
tiledOperands.emplace_back (
927
- getSlice (builder, getLoc (), output (), offsets, sizes, strides));
948
+ getSlice (builder, getLoc (), outputs[0 ], offsets, sizes, strides));
949
+ SmallVector<OpFoldResult> accumOffsets, accumSizes, accumStrides;
950
+ if (rank > 1 ) {
951
+ for (int i = 0 ; i < rank; i++) {
952
+ if (i != dimension ()) {
953
+ accumOffsets.push_back (offsets[i]);
954
+ accumSizes.push_back (sizes[i]);
955
+ accumStrides.push_back (strides[i]);
956
+ }
957
+ }
958
+ tiledOperands.emplace_back (getSlice (
959
+ builder, getLoc (), outputs[1 ], accumOffsets, accumSizes, accumStrides));
960
+ } else {
961
+ tiledOperands.emplace_back (outputs[1 ]);
962
+ }
928
963
929
964
SmallVector<Type, 4 > resultTypes;
930
965
if (hasTensorSemantics ()) {
966
+ resultTypes.push_back (tiledOperands[1 ].getType ());
931
967
resultTypes.push_back (tiledOperands[2 ].getType ());
932
968
}
933
969
934
970
Operation *tiledScanOp = cast<LinalgExtOp>(getOperation ())
935
971
.clone (builder, loc, resultTypes, tiledOperands);
936
972
for (auto result : llvm::enumerate (tiledScanOp->getResults ())) {
973
+ if ((result.index () == resultTypes.size () - 1 ) && (rank > 1 )) {
974
+ offsets = accumOffsets;
975
+ sizes = accumSizes;
976
+ strides = accumStrides;
977
+ }
937
978
auto insertSliceOp = builder.create <tensor::InsertSliceOp>(
938
979
loc, result.value (), outputs[result.index ()], offsets, sizes, strides);
939
980
results.push_back (insertSliceOp.getResult ());
940
981
}
941
982
return tiledScanOp;
942
983
}
943
984
985
+ static LogicalResult foldMemRefCast (Operation *op) {
986
+ bool folded = false ;
987
+ for (OpOperand &operand : op->getOpOperands ()) {
988
+ auto castOp = operand.get ().getDefiningOp <memref::CastOp>();
989
+ if (castOp && memref::CastOp::canFoldIntoConsumerOp (castOp)) {
990
+ operand.set (castOp.getOperand ());
991
+ folded = true ;
992
+ }
993
+ }
994
+ return success (folded);
995
+ }
996
+
997
+ LogicalResult ScanOp::fold (ArrayRef<Attribute>,
998
+ SmallVectorImpl<OpFoldResult> &) {
999
+ return foldMemRefCast (*this );
1000
+ }
1001
+
944
1002
// ===----------------------------------------------------------------------===//
945
1003
// ReverseOp
946
1004
// ===----------------------------------------------------------------------===//
0 commit comments