Skip to content

Commit 2c4ab21

Browse files
author
Xiang Li
committed
Support index_select style gather/scatter access.
Save the stride for the unstructured dimension along a chain of muls. Support i64 indices for TTS_MakeGatherScatterTensorPtrOp. Fix the stride issue caused by the fact that the offset and strides of a subview are relative to the strided memref of the input memref, while the offset and stride of a reinterpret_cast are relative to the base underlying memory of the memref.
1 parent 77e188d commit 2c4ab21

File tree

8 files changed

+1133
-111
lines changed

8 files changed

+1133
-111
lines changed

include/triton-shared/AnalysisStructured/PtrAnalysis.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ const extern std::string ptrAnalysisAttr;
4545
// unstructured offsets. Later, when using the tensor offset to calculate the
4646
// address, it will be collapsed to 1D. To support gather/scatter access, treat
4747
// the unstructured offset as a whole offset instead of decoding the pointer
48-
// arithmetic on it. The stride is set to 1 so it still matches the offset *
49-
// stride formula
48+
// arithmetic on it except scalar mul.
49+
// The stride is set to 1 when there's no scalar mul so it still matches the offset *
50+
// stride formula. When there are scalar muls, the stride is set to the product
51+
// of all the scalar strides.
5052
struct PtrState {
5153
SmallVector<OpFoldResult> offsets;
5254
SmallVector<OpFoldResult> sizes;

include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def TTS_MakeTensorPtrOp
120120
//let hasCanonicalizer = 1;
121121
}
122122

123+
def TT_IndexTensorLike : AnyTypeOf<[I32Tensor, I64Tensor]>;
124+
123125
def TTS_MakeGatherScatterTensorPtrOp
124126
: TTS_Op<"make_gather_scatter_tptr", [AttrSizedOperandSegments, Pure]> {
125127
// NOTE: Only support cases where the offset for each dimension is defined in a different operation.
@@ -160,7 +162,7 @@ def TTS_MakeGatherScatterTensorPtrOp
160162
// result: A tensor of pointers.
161163

162164
let arguments = (ins TT_Ptr:$base,
163-
I32Tensor:$gather_scatter_offset,
165+
TT_IndexTensorLike:$gather_scatter_offset,
164166
I32Attr:$gather_scatter_dim,
165167
DenseI64ArrayAttr:$sizes,
166168
Variadic<Index>:$strides,

lib/AnalysisStructured/PtrAnalysis.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,10 @@ LogicalResult PtrState::mulState(const PtrState &lhsState,
398398
OpFoldResult newOffset =
399399
mulOFRs(lhs->offsets[i], rhsStride, loc, builder);
400400
offsets.push_back(newOffset);
401-
// Set stride to 1 when not continuous.
402-
strides.push_back(builder.getIndexAttr(1));
401+
// Multiply the stride by the scalar.
402+
OpFoldResult newStride =
403+
mulOFRs(lhs->strides[i], rhs->scalar, loc, builder);
404+
strides.push_back(newStride);
403405
}
404406
OpFoldResult newShape =
405407
mulOFRs(lhs->shape[i], rhs->scalar, loc, builder);

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,9 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
865865
gatherOffsetElt.getResult(),
866866
gatherDim, rewriter);
867867
unsigned rank = ptr.getSizes().size();
868+
// Set strides to 1 because subview multiplies the existing strides with the
869+
// stride of the subview.
870+
SmallVector<OpFoldResult> oneStrides(rank, OpFoldResult(step));
868871
// subview from srcPtr for mask.
869872
// With offsets[gatherDim] set to 0 since the offset already in
870873
// reinterpret_cast. With sizes[gatherDim] set to 1 since we are loading one
@@ -875,24 +878,24 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
875878
sizes = mixedDims;
876879
// maskOffsets should be all zero, since srcPtr already has the offsets.
877880
SmallVector<OpFoldResult> maskOffsets(rank, OpFoldResult(lowerBound));
878-
// Use allocStrides for subview.
881+
// Use oneStrides for subview.
879882
auto dstSubViewType = memref::SubViewOp::inferResultType(
880-
cast<MemRefType>(srcPtr.getType()), maskOffsets, sizes, allocStrides);
883+
cast<MemRefType>(srcPtr.getType()), maskOffsets, sizes, oneStrides);
881884
srcPtr =
882885
rewriter
883886
.create<memref::SubViewOp>(loc, cast<MemRefType>(dstSubViewType),
884-
srcPtr, maskOffsets, sizes, allocStrides)
887+
srcPtr, maskOffsets, sizes, oneStrides)
885888
.getResult();
886889
}
887890

888891
// alloc[inductionVar]
889892
SmallVector<OpFoldResult> allocOffsets(rank, OpFoldResult(lowerBound));
890893
allocOffsets[gatherDim] = inductionVar;
891894
auto dstAllocType = memref::SubViewOp::inferResultType(
892-
allocType, allocOffsets, sizes, allocStrides);
895+
allocType, allocOffsets, sizes, oneStrides);
893896
auto dstSubview = rewriter.create<memref::SubViewOp>(
894897
loc, cast<MemRefType>(dstAllocType), alloc, allocOffsets, sizes,
895-
allocStrides);
898+
oneStrides);
896899
// Copy srcPtr to alloc[inductionVar].
897900
rewriter.create<memref::CopyOp>(loc, srcPtr, dstSubview);
898901

@@ -990,8 +993,11 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
990993
mixedDims[gatherDim] = sizes[gatherDim];
991994
sizes = mixedDims;
992995
}
996+
// Set strides to 1 because subview/extract_slice multiplies the existing strides with the
997+
// stride of the subview.
998+
SmallVector<OpFoldResult> oneStrides(rank, OpFoldResult(step));
993999
auto slice = rewriter.create<tensor::ExtractSliceOp>(
994-
loc, stVal, stValOffsets, sizes, strides);
1000+
loc, stVal, stValOffsets, sizes, oneStrides);
9951001

9961002
// reinterpret_cast to current row as memRefPtr[gatherOffsetElt].
9971003
Value dstPtr = rewriteGatherScatterPtrElement(staticSizes, ptr, memRefPtr,
@@ -1003,15 +1009,14 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
10031009
// maskOffsets should be all zero, since srcPtr already has the offsets.
10041010
SmallVector<OpFoldResult> maskOffsets(rank, OpFoldResult(lowerBound));
10051011
auto dstType = memref::SubViewOp::inferResultType(
1006-
cast<MemRefType>(dstPtr.getType()), maskOffsets, sizes, strides);
1012+
cast<MemRefType>(dstPtr.getType()), maskOffsets, sizes, oneStrides);
10071013

10081014
dstPtr =
10091015
rewriter
10101016
.create<memref::SubViewOp>(loc, cast<MemRefType>(dstType), dstPtr,
1011-
maskOffsets, sizes, strides)
1017+
maskOffsets, sizes, oneStrides)
10121018
.getResult();
10131019
}
1014-
10151020
// store slice to dstPtr.
10161021
auto storeOp = rewriter.create<bufferization::MaterializeInDestinationOp>(
10171022
loc, slice, dstPtr);

0 commit comments

Comments
 (0)