Commit 733cfef

python3kgae and Xiang Li authored
Support gather scatter with mod (#284)
When adding or multiplying two PtrStates, if either PtrState is not structured, set the shape to 0 and clear the modulo. This enables gather/scatter when a modulo result is multiplied with an unstructured PtrState.

---------

Co-authored-by: Xiang Li <[email protected]>
1 parent 837fb98 commit 733cfef
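In practice the rule reads as follows. This is a minimal standalone sketch; SimplePtrState and its fields are hypothetical stand-ins for the real PtrState declared in PtrAnalysis.h below.

// Hypothetical, simplified model of the rule in the commit message;
// the real PtrState is richer and lives in PtrAnalysis.h.
#include <algorithm>
#include <cstdint>
#include <vector>

struct SimplePtrState {
  std::vector<int64_t> shape;   // nonzero entry == modulo bound on that dim
  std::vector<int64_t> strides;
  bool structured = true;

  // Combine (add or mul) with another state. If either side is
  // unstructured, zero out the shape, i.e. clear any pending modulo.
  void combine(const SimplePtrState &other) {
    if (!structured || !other.structured) {
      structured = false;
      std::fill(shape.begin(), shape.end(), 0);
    }
  }
};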

File tree

12 files changed: +885 −244 lines

include/triton-shared/AnalysisStructured/PtrAnalysis.h

Lines changed: 41 additions & 2 deletions
@@ -112,11 +112,15 @@ struct PtrState {
 
   // Process addition of two PtrStates.
   LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState,
-                         Operation *op, OpBuilder &builder);
+                         bool isAnalysisingUnstructured, Operation *op,
+                         OpBuilder &builder);
 
   // Process multiplication of two PtrStates
   LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState,
-                         Operation *op, OpBuilder &builder);
+                         bool isAnalysisingUnstructured, Operation *op,
+                         OpBuilder &builder);
+
+  LogicalResult mergeUnstructuredState(const PtrState &other, Operation *op);
 
   tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder,
                                                 Location loc);
@@ -147,6 +151,41 @@ class PtrAnalysis {
 
   DenseSet<Value> maybeStructuredArgs;
   const bool enableMakeGatherScatterTensorPtr;
+  // If false, PtrAnalysis analyzes structured ptrs and only identifies
+  // unstructured ptrs.
+  // If true, PtrAnalysis calculates strides and offsets for unstructured
+  // pointers. This is used to support gather/scatter access.
+  // The default mode is false; it is set to true only when calculating
+  // unstructured pointers for gather/scatter access.
+  // The reason for having two modes is to support cases like:
+  //
+  //   ptr + (row_offsets[:, None] % mod_offset + some_number) +
+  //   row_indices[:, None]
+  //
+  // (row_offsets[:, None] % mod_offset + some_number) is structured and
+  // has a modulo.
+  // row_indices[:, None] is unstructured.
+  // When visiting the add operation, we need to apply the modulo to
+  // (row_offsets[:, None] % mod_offset + some_number), but we don't have
+  // the information about how to apply it.
+  //
+  // To simplify the analysis, we do the work in two modes:
+  // 1. Analyze to identify the unstructured pointers.
+  // 2. Analyze to calculate the strides and offsets for unstructured
+  //    pointers.
+  // In mode 1, isAnalysisingUnstructured is false, so we only identify the
+  // unstructured pointers and do not calculate their strides and offsets.
+  // When visiting the operand again to calculate the offsets and strides
+  // for the unstructured state, we set isAnalysisingUnstructured to true.
+  // This means we have switched to mode 2 and are analyzing the
+  // unstructured pointers and calculating their strides and offsets. In
+  // mode 2, we know the pointer is unstructured, so we can use the value of
+  // arith::RemSIOp directly as the offset. Once the analysis is done, we
+  // switch back to mode 1.
+  //
+  // Note that this might be a temporary solution; we may need to revisit
+  // it in the future to support more complex cases.
+  bool isAnalysisingUnstructured = false;
 
 public:
   PtrAnalysis(bool enableMakeGatherScatterTensorPtr)
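The two-mode scheme in the new comment can be summarized as a standalone sketch. MiniPtrAnalysis and its methods are hypothetical illustrations of the control flow, not the real PtrAnalysis API:

// Hypothetical sketch of the two-mode control flow described in the
// comment above; not the actual PtrAnalysis implementation.
struct MiniPtrAnalysis {
  bool isAnalysisingUnstructured = false; // same flag name as the real class

  bool analyzeOperand() {
    // Mode 1: walk the operand and only classify it. Strides/offsets for
    // unstructured pointers are not computed in this mode.
    bool unstructured = classifyOperand();
    if (!unstructured)
      return true; // structured pointers are fully handled in mode 1

    // Mode 2: revisit the same operand. Because we already know it is
    // unstructured, an arith::RemSIOp result can be used directly as an
    // offset instead of being folded into a shape/modulo.
    isAnalysisingUnstructured = true;
    bool ok = computeUnstructuredOffsetsAndStrides();
    isAnalysisingUnstructured = false; // switch back to mode 1
    return ok;
  }

  // Placeholders standing in for the real analysis work.
  bool classifyOperand() { return true; }
  bool computeUnstructuredOffsetsAndStrides() { return true; }
};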

lib/Analysis/OpFoldResultUtils.cpp

Lines changed: 28 additions & 3 deletions
@@ -14,6 +14,9 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "triton-ptr-analysis"
+
 namespace mlir {
 
 std::optional<int64_t> getIntAttr(const OpFoldResult ofr) {
@@ -115,11 +118,32 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy,
     v = indexTypeCast(v, targetEltTy, loc, b);
     return b.create<triton::SplatOp>(loc, targetTy, v).getResult();
   } else if (targetShapedTy && shapedTy) {
-    // TODO: support ShapedType to ShapedType.
     Type targetEltTy = targetShapedTy.getElementType();
     Type eltTy = shapedTy.getElementType();
-    if (targetShapedTy.getShape() != shapedTy.getShape())
-      llvm_unreachable("ShapedType to ShapedType must have same shape");
+    if (targetShapedTy.getShape() != shapedTy.getShape()) {
+      assert(targetEltTy == eltTy &&
+             "Only cast between same element type shaped types");
+      // This path is for cases like:
+      //   input_ptr + (row_indices[:, None] +
+      //                row_offsets[:, None] % mod_offset) * stride_m +
+      //   col_offsets[None, :] * stride_n
+      // The modulo result has shape [ROW_SIZE, 1] while row_indices has
+      // shape [ROW_SIZE].
+      LLVM_DEBUG({
+        llvm::dbgs() << "Reshaping ";
+        shapedTy.dump();
+        llvm::dbgs() << " to ";
+        targetShapedTy.dump();
+      });
+      SmallVector<Value> shapeValues;
+      for (auto dim : targetShapedTy.getShape()) {
+        shapeValues.push_back(
+            b.create<arith::ConstantOp>(loc, b.getIndexAttr(dim)));
+      }
+      RankedTensorType targetShapeTensorTy = RankedTensorType::get(
+          targetShapedTy.getShape().size(), b.getIndexType());
+      auto shapeTensor = b.create<tensor::FromElementsOp>(
+          loc, targetShapeTensorTy, shapeValues);
+      return b.create<triton::ReshapeOp>(loc, targetTy, v, shapeTensor)
+          .getResult();
+    }
     if (isa<IndexType>(targetEltTy) || isa<IndexType>(eltTy)) {
       assert((isa<IntegerType>(targetEltTy) || isa<IntegerType>(eltTy)) &&
              "Only cast between index type and integer type");
@@ -351,4 +375,5 @@ OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
   auto selectOp = b.create<arith::SelectOp>(loc, cmpOp, trueValue, falseValue);
   return selectOp.getResult();
 }
+
 } // namespace mlir
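For context, a hedged sketch of how the new reshape path gets exercised; the variable names and setup are assumed, and only the expandOFRIndex signature from the hunk header above is taken from the source:

// Hypothetical call site for the new path (setup assumed).
// modOfr:    OpFoldResult holding a tensor<ROW_SIZEx1xindex> modulo result.
// rowIdxOfr: OpFoldResult holding a tensor<ROW_SIZExindex> of row indices.
// Before this change, the shape mismatch hit llvm_unreachable; now the
// first operand is reshaped (via tt.reshape) to the target's shape.
OpFoldResult expanded = expandOFRIndex(modOfr, rowIdxOfr, loc, builder);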
