@@ -112,11 +112,15 @@ struct PtrState {
112
112
113
113
// Process addition of two PtrStates.
114
114
LogicalResult addState (const PtrState &lhsState, const PtrState &rhsState,
115
- Operation *op, OpBuilder &builder);
115
+ bool isAnalysisingUnstructured, Operation *op,
116
+ OpBuilder &builder);
116
117
117
118
// Process multiplication of two PtrStates
118
119
LogicalResult mulState (const PtrState &lhsState, const PtrState &rhsState,
119
- Operation *op, OpBuilder &builder);
120
+ bool isAnalysisingUnstructured, Operation *op,
121
+ OpBuilder &builder);
122
+
123
+ LogicalResult mergeUnstructuredState (const PtrState &other, Operation *op);
120
124
121
125
tts::MakeTensorPtrOp createTTSMakeTensorPtrOp (OpBuilder &builder,
122
126
Location loc);
@@ -147,6 +151,41 @@ class PtrAnalysis {
147
151
148
152
DenseSet<Value> maybeStructuredArgs;
149
153
const bool enableMakeGatherScatterTensorPtr;
154
+ // If false, PtrAnalysis will analyze structured ptr while only identifying
155
+ // unstructured ptr.
156
+ // If true, PtrAnalysis will calculate strides and offsets for
157
+ // unstructured pointers. This is used to support gather/scatter access.
158
+ // The default mode is false. Only set to true when calculating
159
+ // unstructured pointers for gather/scatter access.
160
+ // The reason for having two modes is to support cases like:
161
+ //
162
+ // ptr + (row_offsets[:,None] % mod_offset + some_number) +
163
+ // row_indices[:, None]
164
+ //
165
+ // (row_offsets[:,None] % mod_offset + some_number) is structured and
166
+ // has modulo.
167
+ // row_indices[:, None] is unstructured.
168
+ // When visiting the add operation, we need to apply the modulo to
169
+ // (row_offsets[:,None] % mod_offset + some_number).
170
+ // But we don't have the information about how to apply the modulo.
171
+ //
172
+ // To simplify the analysis, we do the work in two modes:
173
+ // 1. Analyze to identify the unstructured pointers.
174
+ // 2. Analyze to calculate the strides and offsets for unstructured pointers.
175
+ // In mode 1, isAnalysisingUnstructured is set to false, so we only
176
+ // identify the unstructured pointers and do not calculate the strides and
177
+ // offsets.
178
+ // When visiting the operand again to calculate the offsets and strides for the
179
+ // unstructured state, we'll set isAnalysisingUnstructured to true.
180
+ // This means that we switched to mode 2 now and are analyzing the
181
+ // unstructured pointers and calculating the strides and offsets for them. In
182
+ // mode 2, we know that the pointer is unstructured, so we can just use the
183
+ // value of arith::RemSIOp as offset directly. Once the analysis is done,
184
+ // we'll switch back to mode 1.
185
+ //
186
+ // Note that this might be a temporary solution, and we may need to
187
+ // revisit this in the future to support more complex cases.
188
+ bool isAnalysisingUnstructured = false ;
150
189
151
190
public:
152
191
PtrAnalysis (bool enableMakeGatherScatterTensorPtr)
0 commit comments