Skip to content

Commit be1ea63

Browse files
committed
at least build
1 parent de86b85 commit be1ea63

File tree

12 files changed

+182
-233
lines changed

12 files changed

+182
-233
lines changed

deps.bzl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def third_party_deps():
2020
path = local_llvm_repo_path(),
2121
)
2222
else:
23-
LLVM_COMMIT = "72144d119a7291f8b6b8e022a2947fbe31e66afc"
24-
LLVM_SHA256 = "2caacb6925a13cb5886a5d7f225fa408b80ca8e1efe0736186954b2abc4ee1c3"
23+
LLVM_COMMIT = "eda3e96b401a9b86132e39432e41e2000d1ab382"
24+
LLVM_SHA256 = "26c4060f19982482d57f1a47945f3f7613b7659415f0482c4bac63769366b501"
2525
http_archive(
2626
name = "llvm-raw",
2727
build_file_content = "# empty",
@@ -37,8 +37,8 @@ def third_party_deps():
3737
path = local_torch_mlir_repo_path(),
3838
)
3939
else:
40-
TORCH_MLIR_COMMIT = "9f2ba5abaa85cefd95cc85579fafd0c53c1101e8"
41-
TORCH_MLIR_SHA256 = "09444281839eeae4aff42c029d87b1728f307fa26511b896ff448d51aaa98049"
40+
TORCH_MLIR_COMMIT = "1ad9702d2a290b693c4f6f17921d0e0a8d14a999"
41+
TORCH_MLIR_SHA256 = "8843399168c34ca3ca16d2417703fe4e1440ca7240d9e04844b3deedf256f0ab"
4242
http_archive(
4343
name = "torch-mlir-raw",
4444
build_file_content = "# empty",

include/mlir-tcp/Dialect/IR/TcpTypes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ include "mlir-tcp/Dialect/IR/TcpBase.td"
2424
// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
2525
// the 8-bit case.
2626
class Tcp_QuantizedType<string n, list<int> params, bit signed>
27-
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
28-
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
27+
: Type<And<[CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">,
28+
CPred<"::llvm::cast<mlir::quant::QuantizedType>($_self)" #
2929
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
3030
"Q" # !if (signed, "int", "uint") # !head(params) # " type"> {
3131
string name = n;

lib/Conversion/TcpToLinalg/DataMovement.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
3636
matchAndRewrite(GatherOp op, OpAdaptor adaptor,
3737
ConversionPatternRewriter &rewriter) const override {
3838
Location loc = op->getLoc();
39-
auto resultTensorType = getTypeConverter()
40-
->convertType(op.getOut().getType())
41-
.cast<RankedTensorType>();
39+
auto resultTensorType = cast<RankedTensorType>(getTypeConverter()
40+
->convertType(op.getOut().getType()));
4241

4342
auto inputTensor = adaptor.getInput();
4443
auto indicesTensor = adaptor.getIndices();
@@ -110,9 +109,8 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
110109
matchAndRewrite(GatherNDOp op, OpAdaptor adaptor,
111110
ConversionPatternRewriter &rewriter) const override {
112111
Location loc = op->getLoc();
113-
auto resultTensorType = getTypeConverter()
114-
->convertType(op.getOut().getType())
115-
.cast<RankedTensorType>();
112+
auto resultTensorType = cast<RankedTensorType>(getTypeConverter()
113+
->convertType(op.getOut().getType()));
116114

117115
auto inputTensor = adaptor.getInput();
118116
auto indicesTensor = adaptor.getIndices();

lib/Conversion/TcpToLinalg/Elementwise.cpp

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
6969
// This implementation always performs the max followed by min.
7070
// TODO: Is this going to work for degenerative floating point numbers?
7171
Value result = payloadArgs[0];
72-
if (elemType.isa<mlir::FloatType>()) {
72+
if (isa<mlir::FloatType>(elemType)) {
7373
auto minFloat = clampOp.getMinFloat();
7474
auto maxFloat = clampOp.getMaxFloat();
7575
if (minFloat)
@@ -80,7 +80,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
8080
result = b.create<arith::MinimumFOp>(
8181
loc, result,
8282
b.create<arith::ConstantFloatOp>(loc, *maxFloat, b.getF32Type()));
83-
} else if (elemType.isa<mlir::IntegerType>()) {
83+
} else if (isa<mlir::IntegerType>(elemType)) {
8484
auto minInt = clampOp.getMinInt();
8585
auto maxInt = clampOp.getMaxInt();
8686
if (minInt)
@@ -136,9 +136,9 @@ createLinalgPayloadForElementwiseOp(Operation *op,
136136
}
137137

138138
if (isa<AbsOp>(op)) {
139-
if (elemType.isa<mlir::FloatType>())
139+
if (isa<mlir::FloatType>(elemType))
140140
return {b.create<math::AbsFOp>(loc, payloadArgs[0])};
141-
else if (elemType.isa<mlir::IntegerType>())
141+
else if (isa<mlir::IntegerType>(elemType))
142142
return {b.create<math::AbsIOp>(loc, payloadArgs[0])};
143143
else
144144
llvm_unreachable("unsupported element type in "
@@ -158,45 +158,45 @@ createLinalgPayloadForElementwiseOp(Operation *op,
158158
}
159159

160160
if (isa<AddOp>(op)) {
161-
if (elemType.isa<mlir::FloatType>())
161+
if (isa<mlir::FloatType>(elemType))
162162
return {b.create<arith::AddFOp>(loc, payloadArgs[0], payloadArgs[1])};
163-
else if (elemType.isa<mlir::IntegerType>())
163+
else if (isa<mlir::IntegerType>(elemType))
164164
return {b.create<arith::AddIOp>(loc, payloadArgs[0], payloadArgs[1])};
165165
else
166166
llvm_unreachable("unsupported element type in "
167167
"createLinalgPayloadForElementwiseOp for tcp.add");
168168
}
169169

170170
if (isa<SubOp>(op)) {
171-
if (elemType.isa<mlir::FloatType>())
171+
if (isa<mlir::FloatType>(elemType))
172172
return {b.create<arith::SubFOp>(loc, payloadArgs[0], payloadArgs[1])};
173-
else if (elemType.isa<mlir::IntegerType>())
173+
else if (isa<mlir::IntegerType>(elemType))
174174
return {b.create<arith::SubIOp>(loc, payloadArgs[0], payloadArgs[1])};
175175
else
176176
llvm_unreachable("unsupported element type in "
177177
"createLinalgPayloadForElementwiseOp for tcp.sub");
178178
}
179179

180180
if (isa<MulOp>(op)) {
181-
if (elemType.isa<mlir::FloatType>())
181+
if (isa<mlir::FloatType>(elemType))
182182
return {b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1])};
183-
else if (elemType.isa<mlir::IntegerType>())
183+
else if (isa<mlir::IntegerType>(elemType))
184184
return {b.create<arith::MulIOp>(loc, payloadArgs[0], payloadArgs[1])};
185185
else
186186
llvm_unreachable("unsupported element type in "
187187
"createLinalgPayloadForElementwiseOp for tcp.mul");
188188
}
189189

190190
if (isa<DivFOp>(op)) {
191-
if (elemType.isa<mlir::FloatType>())
191+
if (isa<mlir::FloatType>(elemType))
192192
return {b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1])};
193193
else
194194
llvm_unreachable("unsupported element type in "
195195
"createLinalgPayloadForElementwiseOp for tcp.divf");
196196
}
197197

198198
if (auto divOp = dyn_cast<DivSIOp>(op)) {
199-
if (!elemType.isa<mlir::IntegerType>())
199+
if (!isa<mlir::IntegerType>(elemType))
200200
llvm_unreachable("unsupported element type in "
201201
"createLinalgPayloadForElementwiseOp for tcp.divsi");
202202
if (divOp.getRoundingMode() == RoundingMode::Trunc)
@@ -210,7 +210,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
210210
}
211211

212212
if (auto divOp = dyn_cast<DivUIOp>(op)) {
213-
if (!elemType.isa<mlir::IntegerType>())
213+
if (!isa<mlir::IntegerType>(elemType))
214214
llvm_unreachable("unsupported element type in "
215215
"createLinalgPayloadForElementwiseOp for tcp.divui");
216216
if (divOp.getRoundingMode() == RoundingMode::Trunc ||
@@ -222,7 +222,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
222222
}
223223

224224
if (isa<Atan2Op>(op)) {
225-
if (elemType.isa<mlir::FloatType>())
225+
if (isa<mlir::FloatType>(elemType))
226226
return {b.create<math::Atan2Op>(loc, payloadArgs[0], payloadArgs[1])};
227227
else
228228
llvm_unreachable("unsupported element type in "
@@ -231,7 +231,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
231231

232232
if (auto castOp = dyn_cast<CastOp>(op)) {
233233
auto inputType =
234-
castOp.getIn().getType().dyn_cast<RankedTensorType>().getElementType();
234+
dyn_cast<RankedTensorType>(castOp.getIn().getType()).getElementType();
235235
auto outputType = resultTensorType.getElementType();
236236

237237
if (inputType.getIntOrFloatBitWidth() ==
@@ -246,24 +246,24 @@ createLinalgPayloadForElementwiseOp(Operation *op,
246246
// To I1 (Bool) type
247247
Value cstZero =
248248
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputType));
249-
if (inputType.isa<mlir::FloatType>()) {
249+
if (isa<mlir::FloatType>(inputType)) {
250250
return {b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
251251
payloadArgs[0], cstZero)};
252-
} else if (inputType.isa<mlir::IntegerType>()) {
252+
} else if (isa<mlir::IntegerType>(inputType)) {
253253
return {b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
254254
payloadArgs[0], cstZero)};
255255
}
256-
} else if (outputType.isa<mlir::FloatType>()) {
256+
} else if (isa<mlir::FloatType>(outputType)) {
257257
// TO FP type
258258
// FP -> FP
259-
if (inputType.dyn_cast<mlir::FloatType>()) {
259+
if (dyn_cast<mlir::FloatType>(inputType)) {
260260
if (inputType.getIntOrFloatBitWidth() >
261261
outputType.getIntOrFloatBitWidth())
262262
return {b.create<arith::TruncFOp>(loc, outputType, payloadArgs[0])};
263263
return {b.create<arith::ExtFOp>(loc, outputType, payloadArgs[0])};
264264
}
265265
// INT -> FP
266-
else if (inputType.dyn_cast<mlir::IntegerType>()) {
266+
else if (dyn_cast<mlir::IntegerType>(inputType)) {
267267
// Signless or Unsigned INT to FP
268268
// Currently, signless is only for i1 (bool) case,
269269
// which has been handled above
@@ -274,10 +274,10 @@ createLinalgPayloadForElementwiseOp(Operation *op,
274274
else if (castOp.getInIntSignedness().value() == Signedness::Signed)
275275
return {b.create<arith::SIToFPOp>(loc, outputType, payloadArgs[0])};
276276
}
277-
} else if (outputType.isa<mlir::IntegerType>()) {
277+
} else if (isa<mlir::IntegerType>(outputType)) {
278278
// TO INT type
279279
// FP -> INT
280-
if (inputType.dyn_cast<mlir::FloatType>()) {
280+
if (dyn_cast<mlir::FloatType>(inputType)) {
281281
// FP to Signless or Unsigned INT
282282
if (castOp.getOutIntSignedness().value() == Signedness::Signless ||
283283
castOp.getOutIntSignedness().value() == Signedness::Unsigned)
@@ -287,7 +287,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
287287
return {b.create<arith::FPToSIOp>(loc, outputType, payloadArgs[0])};
288288
}
289289
// INT -> INT
290-
if (inputType.dyn_cast<mlir::IntegerType>()) {
290+
if (dyn_cast<mlir::IntegerType>(inputType)) {
291291
if (inputType.getIntOrFloatBitWidth() >
292292
outputType.getIntOrFloatBitWidth())
293293
return {b.create<arith::TruncIOp>(loc, outputType, payloadArgs[0])};
@@ -318,12 +318,11 @@ class ConvertElementwiseOp : public OpConversionPattern<TcpOpT> {
318318
matchAndRewrite(TcpOpT op, OpAdaptor adaptor,
319319
ConversionPatternRewriter &rewriter) const override {
320320
Location loc = op->getLoc();
321-
auto resultTensorType = OpConversionPattern<TcpOpT>::getTypeConverter()
322-
->convertType(op->getResult(0).getType())
323-
.template cast<RankedTensorType>();
321+
auto resultTensorType = cast<RankedTensorType>(OpConversionPattern<TcpOpT>::getTypeConverter()
322+
->convertType(op->getResult(0).getType()));
324323
auto tensorOperands = llvm::to_vector<6>(
325324
llvm::make_filter_range(adaptor.getOperands(), [](Value v) {
326-
return v.getType().isa<RankedTensorType>();
325+
return isa<RankedTensorType>(v.getType());
327326
}));
328327

329328
// Create Linalg payload

lib/Conversion/TcpToLinalg/Misc.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace {
2929
SmallVector<int64_t> getValuesFromIndexArrayAttribute(ArrayAttr attr) {
3030
SmallVector<int64_t> arrayValues;
3131
for (Attribute val : attr.getValue())
32-
arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
32+
arrayValues.push_back(cast<IntegerAttr>(val).getValue().getSExtValue());
3333
return arrayValues;
3434
}
3535

@@ -40,9 +40,8 @@ class ConvertBroadcastOp : public OpConversionPattern<BroadcastOp> {
4040
LogicalResult matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
4141
ConversionPatternRewriter &b) const override {
4242
Location loc = op->getLoc();
43-
auto resultTensorType = getTypeConverter()
44-
->convertType(op->getResult(0).getType())
45-
.cast<RankedTensorType>();
43+
auto resultTensorType = cast<RankedTensorType>(getTypeConverter()
44+
->convertType(op->getResult(0).getType()));
4645
auto inputTensor = op->getOperands()[0];
4746

4847
SmallVector<int64_t> axes = getValuesFromIndexArrayAttribute(op.getAxes());

lib/Conversion/TorchToTcp/DataMovement.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
4141
Location loc = op.getLoc();
4242
auto input = adaptor.getSelf();
4343
RankedTensorType inputType =
44-
input.getType().template cast<RankedTensorType>();
44+
cast<RankedTensorType>(input.getType());
4545

4646
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4747
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -64,8 +64,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
6464
Value builtinTypeStart = adaptor.getStart();
6565
Value builtinTypeEnd = adaptor.getEnd();
6666

67-
if (torchTypeStart.getType().isa<OptionalType>() ||
68-
torchTypeEnd.getType().isa<OptionalType>())
67+
if (isa<OptionalType>(torchTypeStart.getType()) ||
68+
isa<OptionalType>(torchTypeEnd.getType()))
6969
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
7070

7171
Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
@@ -75,7 +75,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
7575
// We cannot use the positive valid dim as for negative strides we need to
7676
// clamp to `-1` so that the full tensor bounds are available:
7777
Value end = builtinTypeEnd;
78-
if (torchTypeEnd.getType().isa<Torch::NoneType>()) {
78+
if (isa<Torch::NoneType>(torchTypeEnd.getType())) {
7979
end = dimSize;
8080
} else {
8181
end = castIntToIndex(rewriter, loc, end);
@@ -140,7 +140,7 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
140140
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
141141

142142
RankedTensorType newResultType =
143-
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
143+
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
144144
int rank = newResultType.getRank();
145145
Value dimValue = op.getDim();
146146
int64_t dim;
@@ -185,9 +185,8 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
185185
return failure();
186186

187187
auto input = adaptor.getSelf();
188-
RankedTensorType resultType = getTypeConverter()
189-
->convertType(op->getResult(0).getType())
190-
.cast<RankedTensorType>();
188+
RankedTensorType resultType = cast<RankedTensorType>(getTypeConverter()
189+
->convertType(op->getResult(0).getType()));
191190

192191
SmallVector<Value> resultShape;
193192
SmallVector<Value> offsets;
@@ -213,14 +212,13 @@ class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
213212
ConversionPatternRewriter &rewriter) const override {
214213
auto input = adaptor.getSelf();
215214
auto indices = adaptor.getIndex();
216-
RankedTensorType resultType = getTypeConverter()
217-
->convertType(op->getResult(0).getType())
218-
.template cast<RankedTensorType>();
215+
RankedTensorType resultType = cast<RankedTensorType>(getTypeConverter()
216+
->convertType(op->getResult(0).getType()));
219217

220218
int64_t dim = 0;
221219
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
222220
return op.emitError("dim on torch.gather must be an int constant");
223-
auto inputType = input.getType().cast<RankedTensorType>();
221+
auto inputType = cast<RankedTensorType>(input.getType());
224222
dim = Torch::toPositiveDim(dim, inputType.getRank());
225223

226224
bool sparseGrad = false;

0 commit comments

Comments
 (0)