@@ -865,6 +865,9 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
                                             gatherOffsetElt.getResult(),
                                             gatherDim, rewriter);
     unsigned rank = ptr.getSizes().size();
+    // Set strides to 1 because subview multiplies the existing strides with
+    // the stride of the subview.
+    SmallVector<OpFoldResult> oneStrides(rank, OpFoldResult(step));
     // subview from srcPtr for mask.
     // With offsets[gatherDim] set to 0 since the offset is already in
     // reinterpret_cast. With sizes[gatherDim] set to 1 since we load one
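
Note on the new comment above: memref.subview composes the strides given to it with the strides already present in the source layout, so reusing the block strides here would scale the layout produced by the reinterpret_cast a second time; unit strides leave it untouched. A standalone sketch of that composition using the same memref::SubViewOp::inferResultType helper (not part of the patch; the shapes, strides, and main() harness are illustrative assumptions):

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      mlir::MLIRContext ctx;
      mlir::Builder b(&ctx);

      // Source layout standing in for the reinterpret_cast result:
      // memref<4x8xf32, strided<[8, 1]>>.
      auto layout = mlir::StridedLayoutAttr::get(&ctx, /*offset=*/0,
                                                 /*strides=*/{8, 1});
      auto srcTy = mlir::MemRefType::get({4, 8}, b.getF32Type(), layout);

      llvm::SmallVector<int64_t> offsets{0, 0}, sizes{1, 8};
      llvm::SmallVector<int64_t> blockStrides{2, 2}, unitStrides{1, 1};

      // Non-unit subview strides are multiplied onto the source strides:
      // the result is memref<1x8xf32, strided<[16, 2]>>.
      llvm::errs() << mlir::memref::SubViewOp::inferResultType(
                          srcTy, offsets, sizes, blockStrides)
                   << "\n";
      // Unit strides leave the source strides untouched: strided<[8, 1]>.
      llvm::errs() << mlir::memref::SubViewOp::inferResultType(
                          srcTy, offsets, sizes, unitStrides)
                   << "\n";
      return 0;
    }
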
@@ -875,24 +878,24 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
       sizes = mixedDims;
       // maskOffsets should be all zero, since srcPtr already has the offsets.
       SmallVector<OpFoldResult> maskOffsets(rank, OpFoldResult(lowerBound));
-      // Use allocStrides for subview.
+      // Use oneStrides for subview.
       auto dstSubViewType = memref::SubViewOp::inferResultType(
-          cast<MemRefType>(srcPtr.getType()), maskOffsets, sizes, allocStrides);
+          cast<MemRefType>(srcPtr.getType()), maskOffsets, sizes, oneStrides);
       srcPtr =
           rewriter
               .create<memref::SubViewOp>(loc, cast<MemRefType>(dstSubViewType),
-                                          srcPtr, maskOffsets, sizes, allocStrides)
+                                          srcPtr, maskOffsets, sizes, oneStrides)
               .getResult();
     }

     // alloc[inductionVar]
     SmallVector<OpFoldResult> allocOffsets(rank, OpFoldResult(lowerBound));
     allocOffsets[gatherDim] = inductionVar;
     auto dstAllocType = memref::SubViewOp::inferResultType(
-        allocType, allocOffsets, sizes, allocStrides);
+        allocType, allocOffsets, sizes, oneStrides);
     auto dstSubview = rewriter.create<memref::SubViewOp>(
         loc, cast<MemRefType>(dstAllocType), alloc, allocOffsets, sizes,
-        allocStrides);
+        oneStrides);
     // Copy srcPtr to alloc[inductionVar].
     rewriter.create<memref::CopyOp>(loc, srcPtr, dstSubview);

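
The copy into the staging buffer follows the same rule: the subview of alloc pins only the gather dimension to the current iteration and keeps unit strides so alloc's own strides are not scaled again. A rough helper-style restatement of that step (a sketch only; the function name copyRowIntoAlloc, its parameter list, and the assumption that sizes already has the gather dimension set to 1 are mine, not the patch's):

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinTypes.h"

    // Copy one gathered row `srcRow` into `alloc` at position `inductionVar`
    // along `gatherDim`. Offsets are zero everywhere else and strides stay 1
    // so the alloc's own strides are not scaled a second time by the subview.
    void copyRowIntoAlloc(mlir::OpBuilder &builder, mlir::Location loc,
                          mlir::Value srcRow, mlir::Value alloc,
                          mlir::Value inductionVar, unsigned gatherDim,
                          llvm::ArrayRef<mlir::OpFoldResult> sizes) {
      auto allocType = mlir::cast<mlir::MemRefType>(alloc.getType());
      unsigned rank = allocType.getRank();

      llvm::SmallVector<mlir::OpFoldResult> offsets(rank,
                                                    builder.getIndexAttr(0));
      offsets[gatherDim] = inductionVar;
      llvm::SmallVector<mlir::OpFoldResult> strides(rank,
                                                    builder.getIndexAttr(1));

      auto dstType = mlir::memref::SubViewOp::inferResultType(
          allocType, offsets, sizes, strides);
      auto dstSubview = builder.create<mlir::memref::SubViewOp>(
          loc, mlir::cast<mlir::MemRefType>(dstType), alloc, offsets, sizes,
          strides);
      builder.create<mlir::memref::CopyOp>(loc, srcRow, dstSubview);
    }
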
@@ -990,8 +993,11 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
       mixedDims[gatherDim] = sizes[gatherDim];
       sizes = mixedDims;
     }
+    // Set strides to 1 because subview/extract_slice multiplies the existing
+    // strides with the stride of the subview.
+    SmallVector<OpFoldResult> oneStrides(rank, OpFoldResult(step));
     auto slice = rewriter.create<tensor::ExtractSliceOp>(
-        loc, stVal, stValOffsets, sizes, strides);
+        loc, stVal, stValOffsets, sizes, oneStrides);

     // reinterpret_cast to current row as memRefPtr[gatherOffsetElt].
     Value dstPtr = rewriteGatherScatterPtrElement(staticSizes, ptr, memRefPtr,
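
On the store side the value being written is still a tensor, so the current row is carved out with tensor.extract_slice; here the strides are plain sampling strides (there is no layout to compose with), and 1 simply takes the row contiguously, matching the unit-stride subview built for the destination in the next hunk. A sketch of that extraction (hypothetical helper; it assumes the offsets are zero except for the gather dimension indexed by the loop induction variable, as on the load side):

    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinTypes.h"

    // Extract one row of the value tensor `stVal` at `inductionVar` along
    // `gatherDim`. `sizes` is assumed to already have the gather dimension
    // set to 1; offsets are zero elsewhere and strides are all 1.
    mlir::Value extractStoreRow(mlir::OpBuilder &builder, mlir::Location loc,
                                mlir::Value stVal, mlir::Value inductionVar,
                                unsigned gatherDim,
                                llvm::ArrayRef<mlir::OpFoldResult> sizes) {
      auto tensorType = mlir::cast<mlir::RankedTensorType>(stVal.getType());
      unsigned rank = tensorType.getRank();

      llvm::SmallVector<mlir::OpFoldResult> offsets(rank,
                                                    builder.getIndexAttr(0));
      offsets[gatherDim] = inductionVar;
      llvm::SmallVector<mlir::OpFoldResult> strides(rank,
                                                    builder.getIndexAttr(1));

      // The sliced result type is inferred by the builder from
      // offsets/sizes/strides.
      return builder.create<mlir::tensor::ExtractSliceOp>(loc, stVal, offsets,
                                                          sizes, strides);
    }
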
@@ -1003,15 +1009,14 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
       // maskOffsets should be all zero, since srcPtr already has the offsets.
       SmallVector<OpFoldResult> maskOffsets(rank, OpFoldResult(lowerBound));
       auto dstType = memref::SubViewOp::inferResultType(
-          cast<MemRefType>(dstPtr.getType()), maskOffsets, sizes, strides);
+          cast<MemRefType>(dstPtr.getType()), maskOffsets, sizes, oneStrides);

       dstPtr =
           rewriter
               .create<memref::SubViewOp>(loc, cast<MemRefType>(dstType), dstPtr,
-                                          maskOffsets, sizes, strides)
+                                          maskOffsets, sizes, oneStrides)
               .getResult();
     }
-
     // store slice to dstPtr.
     auto storeOp = rewriter.create<bufferization::MaterializeInDestinationOp>(
         loc, slice, dstPtr);
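
The zero maskOffsets mirror the load path: the per-row offset is already folded into dstPtr's layout by the reinterpret_cast, so the subview only trims the sizes, and any non-zero offset here would be added on top of the existing one. A small standalone check of that behaviour (illustrative values, not from the patch):

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      mlir::MLIRContext ctx;
      mlir::Builder b(&ctx);

      // A destination row whose reinterpret_cast already carries offset 42:
      // memref<8xf32, strided<[1], offset: 42>>.
      auto layout = mlir::StridedLayoutAttr::get(&ctx, /*offset=*/42,
                                                 /*strides=*/{1});
      auto dstTy = mlir::MemRefType::get({8}, b.getF32Type(), layout);

      llvm::SmallVector<int64_t> sizes{4}, ones{1};
      llvm::SmallVector<int64_t> zeroOffset{0}, threeOffset{3};

      // Zero offsets keep the existing offset: strided<[1], offset: 42>.
      llvm::errs() << mlir::memref::SubViewOp::inferResultType(
                          dstTy, zeroOffset, sizes, ones)
                   << "\n";
      // A non-zero offset is applied on top: strided<[1], offset: 45>.
      llvm::errs() << mlir::memref::SubViewOp::inferResultType(
                          dstTy, threeOffset, sizes, ones)
                   << "\n";
      return 0;
    }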