@@ -36,7 +36,8 @@ namespace {
36
36
// / Generic conversion for any DestinationStyleOpInterface on tensors.
37
37
static LogicalResult bufferizeTritonTilingExtDestinationStyleOpInterface (
38
38
RewriterBase &rewriter, DestinationStyleOpInterface op,
39
- const BufferizationOptions &options) {
39
+ const BufferizationOptions &options,
40
+ BufferizationState &state) {
40
41
// Take a guard before anything else.
41
42
OpBuilder::InsertionGuard g (rewriter);
42
43
rewriter.setInsertionPoint (op);
@@ -58,7 +59,7 @@ static LogicalResult bufferizeTritonTilingExtDestinationStyleOpInterface(
58
59
newInputBuffers.push_back (opOperand->get ());
59
60
continue ;
60
61
}
61
- FailureOr<Value> buffer = getBuffer (rewriter, opOperand->get (), options);
62
+ FailureOr<Value> buffer = getBuffer (rewriter, opOperand->get (), options, state );
62
63
if (failed (buffer))
63
64
return failure ();
64
65
newInputBuffers.push_back (*buffer);
@@ -69,7 +70,7 @@ static LogicalResult bufferizeTritonTilingExtDestinationStyleOpInterface(
69
70
for (OpResult opResult : op->getOpResults ()) {
70
71
OpOperand *opOperand = op.getDpsInitOperand (opResult.getResultNumber ());
71
72
FailureOr<Value> resultBuffer =
72
- getBuffer (rewriter, opOperand->get (), options);
73
+ getBuffer (rewriter, opOperand->get (), options, state );
73
74
if (failed (resultBuffer))
74
75
return failure ();
75
76
newOutputBuffers.push_back (*resultBuffer);
@@ -109,9 +110,10 @@ struct TritonTilingExtOpInterface
109
110
}
110
111
111
112
LogicalResult bufferize (Operation *op, RewriterBase &rewriter,
112
- const BufferizationOptions &options) const {
113
+ const BufferizationOptions &options,
114
+ BufferizationState &state) const {
113
115
return bufferizeTritonTilingExtDestinationStyleOpInterface (
114
- rewriter, cast<DestinationStyleOpInterface>(op), options);
116
+ rewriter, cast<DestinationStyleOpInterface>(op), options, state );
115
117
}
116
118
};
117
119
0 commit comments