@@ -69,7 +69,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
6969 // This implementation always performs the max followed by min.
7070 // TODO: Is this going to work for degenerative floating point numbers?
7171 Value result = payloadArgs[0 ];
72- if (elemType. isa <mlir::FloatType>()) {
72+ if (isa<mlir::FloatType>(elemType )) {
7373 auto minFloat = clampOp.getMinFloat ();
7474 auto maxFloat = clampOp.getMaxFloat ();
7575 if (minFloat)
@@ -80,7 +80,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
8080 result = b.create <arith::MinimumFOp>(
8181 loc, result,
8282 b.create <arith::ConstantFloatOp>(loc, *maxFloat, b.getF32Type ()));
83- } else if (elemType. isa <mlir::IntegerType>()) {
83+ } else if (isa<mlir::IntegerType>(elemType )) {
8484 auto minInt = clampOp.getMinInt ();
8585 auto maxInt = clampOp.getMaxInt ();
8686 if (minInt)
@@ -136,9 +136,9 @@ createLinalgPayloadForElementwiseOp(Operation *op,
136136 }
137137
138138 if (isa<AbsOp>(op)) {
139- if (elemType. isa <mlir::FloatType>())
139+ if (isa<mlir::FloatType>(elemType ))
140140 return {b.create <math::AbsFOp>(loc, payloadArgs[0 ])};
141- else if (elemType. isa <mlir::IntegerType>())
141+ else if (isa<mlir::IntegerType>(elemType ))
142142 return {b.create <math::AbsIOp>(loc, payloadArgs[0 ])};
143143 else
144144 llvm_unreachable (" unsupported element type in "
@@ -158,45 +158,45 @@ createLinalgPayloadForElementwiseOp(Operation *op,
158158 }
159159
160160 if (isa<AddOp>(op)) {
161- if (elemType. isa <mlir::FloatType>())
161+ if (isa<mlir::FloatType>(elemType ))
162162 return {b.create <arith::AddFOp>(loc, payloadArgs[0 ], payloadArgs[1 ])};
163- else if (elemType. isa <mlir::IntegerType>())
163+ else if (isa<mlir::IntegerType>(elemType ))
164164 return {b.create <arith::AddIOp>(loc, payloadArgs[0 ], payloadArgs[1 ])};
165165 else
166166 llvm_unreachable (" unsupported element type in "
167167 " createLinalgPayloadForElementwiseOp for tcp.add" );
168168 }
169169
170170 if (isa<SubOp>(op)) {
171- if (elemType. isa <mlir::FloatType>())
171+ if (isa<mlir::FloatType>(elemType ))
172172 return {b.create <arith::SubFOp>(loc, payloadArgs[0 ], payloadArgs[1 ])};
173- else if (elemType. isa <mlir::IntegerType>())
173+ else if (isa<mlir::IntegerType>(elemType ))
174174 return {b.create <arith::SubIOp>(loc, payloadArgs[0 ], payloadArgs[1 ])};
175175 else
176176 llvm_unreachable (" unsupported element type in "
177177 " createLinalgPayloadForElementwiseOp fot tcp.sub" );
178178 }
179179
180180 if (isa<MulOp>(op)) {
181- if (elemType. isa <mlir::FloatType>())
181+ if (isa<mlir::FloatType>(elemType ))
182182 return {b.create <arith::MulFOp>(loc, payloadArgs[0 ], payloadArgs[1 ])};
183- else if (elemType. isa <mlir::IntegerType>())
183+ else if (isa<mlir::IntegerType>(elemType ))
184184 return {b.create <arith::MulIOp>(loc, payloadArgs[0 ], payloadArgs[1 ])};
185185 else
186186 llvm_unreachable (" unsupported element type in "
187187 " createLinalgPayloadForElementwiseOp for tcp.mul" );
188188 }
189189
190190 if (isa<DivFOp>(op)) {
191- if (elemType. isa <mlir::FloatType>())
191+ if (isa<mlir::FloatType>(elemType ))
192192 return {b.create <arith::DivFOp>(loc, payloadArgs[0 ], payloadArgs[1 ])};
193193 else
194194 llvm_unreachable (" unsupported element type in "
195195 " createLinalgPayloadForElementwiseOp for tcp.divf" );
196196 }
197197
198198 if (auto divOp = dyn_cast<DivSIOp>(op)) {
199- if (!elemType. isa <mlir::IntegerType>())
199+ if (!isa<mlir::IntegerType>(elemType ))
200200 llvm_unreachable (" unsupported element type in "
201201 " createLinalgPayloadForElementwiseOp for tcp.divsi" );
202202 if (divOp.getRoundingMode () == RoundingMode::Trunc)
@@ -210,7 +210,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
210210 }
211211
212212 if (auto divOp = dyn_cast<DivUIOp>(op)) {
213- if (!elemType. isa <mlir::IntegerType>())
213+ if (!isa<mlir::IntegerType>(elemType ))
214214 llvm_unreachable (" unsupported element type in "
215215 " createLinalgPayloadForElementwiseOp for tcp.divui" );
216216 if (divOp.getRoundingMode () == RoundingMode::Trunc ||
@@ -222,7 +222,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
222222 }
223223
224224 if (isa<Atan2Op>(op)) {
225- if (elemType. isa <mlir::FloatType>())
225+ if (isa<mlir::FloatType>(elemType ))
226226 return {b.create <math::Atan2Op>(loc, payloadArgs[0 ], payloadArgs[1 ])};
227227 else
228228 llvm_unreachable (" unsupported element type in "
@@ -231,7 +231,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
231231
232232 if (auto castOp = dyn_cast<CastOp>(op)) {
233233 auto inputType =
234- castOp.getIn ().getType (). dyn_cast <RankedTensorType>( ).getElementType ();
234+ dyn_cast<RankedTensorType>( castOp.getIn ().getType ()).getElementType ();
235235 auto outputType = resultTensorType.getElementType ();
236236
237237 if (inputType.getIntOrFloatBitWidth () ==
@@ -246,24 +246,24 @@ createLinalgPayloadForElementwiseOp(Operation *op,
246246 // To I1 (Bool) type
247247 Value cstZero =
248248 b.create <arith::ConstantOp>(loc, b.getZeroAttr (inputType));
249- if (inputType. isa <mlir::FloatType>()) {
249+ if (isa<mlir::FloatType>(inputType )) {
250250 return {b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
251251 payloadArgs[0 ], cstZero)};
252- } else if (inputType. isa <mlir::IntegerType>()) {
252+ } else if (isa<mlir::IntegerType>(inputType )) {
253253 return {b.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
254254 payloadArgs[0 ], cstZero)};
255255 }
256- } else if (outputType. isa <mlir::FloatType>()) {
256+ } else if (isa<mlir::FloatType>(outputType )) {
257257 // TO FP type
258258 // FP -> FP
259- if (inputType. dyn_cast <mlir::FloatType>()) {
259+ if (dyn_cast<mlir::FloatType>(inputType )) {
260260 if (inputType.getIntOrFloatBitWidth () >
261261 outputType.getIntOrFloatBitWidth ())
262262 return {b.create <arith::TruncFOp>(loc, outputType, payloadArgs[0 ])};
263263 return {b.create <arith::ExtFOp>(loc, outputType, payloadArgs[0 ])};
264264 }
265265 // INT -> FP
266- else if (inputType. dyn_cast <mlir::IntegerType>()) {
266+ else if (dyn_cast<mlir::IntegerType>(inputType )) {
267267 // Signless or Unsigned INT to FP
268268 // Curently, signless is only for i1 (bool) case,
269269 // which has been handeled above
@@ -274,10 +274,10 @@ createLinalgPayloadForElementwiseOp(Operation *op,
274274 else if (castOp.getInIntSignedness ().value () == Signedness::Signed)
275275 return {b.create <arith::SIToFPOp>(loc, outputType, payloadArgs[0 ])};
276276 }
277- } else if (outputType. isa <mlir::IntegerType>()) {
277+ } else if (isa<mlir::IntegerType>(outputType )) {
278278 // TO INT type
279279 // FP -> INT
280- if (inputType. dyn_cast <mlir::FloatType>()) {
280+ if (dyn_cast<mlir::FloatType>(inputType )) {
281281 // FP to Signless or Unsigned INT
282282 if (castOp.getOutIntSignedness ().value () == Signedness::Signless ||
283283 castOp.getOutIntSignedness ().value () == Signedness::Unsigned)
@@ -287,7 +287,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
287287 return {b.create <arith::FPToSIOp>(loc, outputType, payloadArgs[0 ])};
288288 }
289289 // INT -> INT
290- if (inputType. dyn_cast <mlir::IntegerType>()) {
290+ if (dyn_cast<mlir::IntegerType>(inputType )) {
291291 if (inputType.getIntOrFloatBitWidth () >
292292 outputType.getIntOrFloatBitWidth ())
293293 return {b.create <arith::TruncIOp>(loc, outputType, payloadArgs[0 ])};
@@ -318,12 +318,11 @@ class ConvertElementwiseOp : public OpConversionPattern<TcpOpT> {
318318 matchAndRewrite (TcpOpT op, OpAdaptor adaptor,
319319 ConversionPatternRewriter &rewriter) const override {
320320 Location loc = op->getLoc ();
321- auto resultTensorType = OpConversionPattern<TcpOpT>::getTypeConverter ()
322- ->convertType (op->getResult (0 ).getType ())
323- .template cast <RankedTensorType>();
321+ auto resultTensorType = cast<RankedTensorType>(OpConversionPattern<TcpOpT>::getTypeConverter ()
322+ ->convertType (op->getResult (0 ).getType ()));
324323 auto tensorOperands = llvm::to_vector<6 >(
325324 llvm::make_filter_range (adaptor.getOperands (), [](Value v) {
326- return v. getType (). isa <RankedTensorType>();
325+ return isa<RankedTensorType>(v. getType () );
327326 }));
328327
329328 // Create Linalg payload
0 commit comments