Skip to content

Commit b43def2

Browse files
committed
[CIR] Refactor complex type
Apply `CIR_AnyIntOrFloat` type constraint on element type and rename `elementTy` to `elementType`.
1 parent 64bf534 commit b43def2

File tree

12 files changed

+91
-81
lines changed

12 files changed

+91
-81
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -311,12 +311,12 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
311311

312312
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
313313
auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
314-
return create<cir::ComplexRealOp>(loc, operandTy.getElementTy(), operand);
314+
return create<cir::ComplexRealOp>(loc, operandTy.getElementType(), operand);
315315
}
316316

317317
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
318318
auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
319-
return create<cir::ComplexImagOp>(loc, operandTy.getElementTy(), operand);
319+
return create<cir::ComplexImagOp>(loc, operandTy.getElementType(), operand);
320320
}
321321

322322
mlir::Value createComplexBinOp(mlir::Location loc, mlir::Value lhs,

clang/include/clang/CIR/Dialect/IR/CIROps.td

+8-5
Original file line numberDiff line numberDiff line change
@@ -1386,7 +1386,10 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
13861386
}];
13871387

13881388
let results = (outs CIR_ComplexType:$result);
1389-
let arguments = (ins CIR_AnyIntOrFloat:$real, CIR_AnyIntOrFloat:$imag);
1389+
let arguments = (ins
1390+
CIR_AnyIntOrFloatType:$real,
1391+
CIR_AnyIntOrFloatType:$imag
1392+
);
13901393

13911394
let assemblyFormat = [{
13921395
$real `,` $imag
@@ -1414,7 +1417,7 @@ def ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
14141417
```
14151418
}];
14161419

1417-
let results = (outs CIR_AnyIntOrFloat:$result);
1420+
let results = (outs CIR_AnyIntOrFloatType:$result);
14181421
let arguments = (ins CIR_ComplexType:$operand);
14191422

14201423
let assemblyFormat = [{
@@ -1439,7 +1442,7 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
14391442
```
14401443
}];
14411444

1442-
let results = (outs CIR_AnyIntOrFloat:$result);
1445+
let results = (outs CIR_AnyIntOrFloatType:$result);
14431446
let arguments = (ins CIR_ComplexType:$operand);
14441447

14451448
let assemblyFormat = [{
@@ -5564,9 +5567,9 @@ def AtomicFetch : CIR_Op<"atomic.fetch",
55645567
%res = cir.atomic.fetch(add, %ptr : !cir.ptr<!s32i>,
55655568
%val : !s32i, seq_cst) : !s32i
55665569
}];
5567-
let results = (outs CIR_AnyIntOrFloat:$result);
5570+
let results = (outs CIR_AnyIntOrFloatType:$result);
55685571
let arguments = (ins Arg<PrimitiveIntOrFPPtr, "", [MemRead, MemWrite]>:$ptr,
5569-
CIR_AnyIntOrFloat:$val,
5572+
CIR_AnyIntOrFloatType:$val,
55705573
AtomicFetchKind:$binop,
55715574
Arg<MemOrder, "memory order">:$mem_order,
55725575
UnitAttr:$is_volatile,

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

+11-1
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,16 @@ def CIR_AnyFloatType : AnyTypeOf<[
135135
let cppFunctionName = "isAnyFloatingPointType";
136136
}
137137

138-
def CIR_AnyIntOrFloat : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType]>;
138+
def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
139+
"integer or floating point type"
140+
> {
141+
let cppFunctionName = "isAnyIntegerOrFloatingPointType";
142+
}
143+
144+
//===----------------------------------------------------------------------===//
145+
// Complex Type predicates
146+
//===----------------------------------------------------------------------===//
147+
148+
def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;
139149

140150
#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+18-7
Original file line numberDiff line numberDiff line change
@@ -159,24 +159,35 @@ def CIR_ComplexType : CIR_Type<"Complex", "complex",
159159
CIR type that represents a C complex number. `cir.complex` models the C type
160160
`T _Complex`.
161161

162-
The parameter `elementTy` gives the type of the real and imaginary part of
163-
the complex number. `elementTy` must be either a CIR integer type or a CIR
162+
The type models complex values, per C99 6.2.5p11. It supports the C99
163+
complex float types as well as the GCC integer complex extensions.
164+
165+
The parameter `elementType` gives the type of the real and imaginary part of
166+
the complex number. `elementType` must be either a CIR integer type or a CIR
164167
floating-point type.
165168
}];
166169

167-
let parameters = (ins "mlir::Type":$elementTy);
170+
let parameters = (ins CIR_AnyIntOrFloatType:$elementType);
168171

169172
let builders = [
170-
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementTy), [{
171-
return $_get(elementTy.getContext(), elementTy);
173+
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
174+
return $_get(elementType.getContext(), elementType);
172175
}]>,
173176
];
174177

175178
let assemblyFormat = [{
176-
`<` $elementTy `>`
179+
`<` $elementType `>`
177180
}];
178181

179-
let genVerifyDecl = 1;
182+
let extraClassDeclaration = [{
183+
bool isFloatingPointComplex() const {
184+
return isAnyFloatingPointType(getElementType());
185+
}
186+
187+
bool isIntegerComplex() const {
188+
return mlir::isa<cir::IntType>(getElementType());
189+
}
190+
}];
180191
}
181192

182193
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
816816
auto srcPtrTy = mlir::cast<cir::PointerType>(value.getType());
817817
auto srcComplexTy = mlir::cast<cir::ComplexType>(srcPtrTy.getPointee());
818818
return create<cir::ComplexRealPtrOp>(
819-
loc, getPointerTo(srcComplexTy.getElementTy()), value);
819+
loc, getPointerTo(srcComplexTy.getElementType()), value);
820820
}
821821

822822
Address createRealPtr(mlir::Location loc, Address addr) {
@@ -830,7 +830,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
830830
auto srcPtrTy = mlir::cast<cir::PointerType>(value.getType());
831831
auto srcComplexTy = mlir::cast<cir::ComplexType>(srcPtrTy.getPointee());
832832
return create<cir::ComplexImagPtrOp>(
833-
loc, getPointerTo(srcComplexTy.getElementTy()), value);
833+
loc, getPointerTo(srcComplexTy.getElementType()), value);
834834
}
835835

836836
Address createImagPtr(mlir::Location loc, Address addr) {

clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ mlir::Value
831831
ComplexExprEmitter::VisitImaginaryLiteral(const ImaginaryLiteral *IL) {
832832
auto Loc = CGF.getLoc(IL->getExprLoc());
833833
auto Ty = mlir::cast<cir::ComplexType>(CGF.convertType(IL->getType()));
834-
auto ElementTy = Ty.getElementTy();
834+
auto ElementTy = Ty.getElementType();
835835

836836
mlir::TypedAttr RealValueAttr;
837837
mlir::TypedAttr ImagValueAttr;
@@ -875,17 +875,15 @@ mlir::Value CIRGenFunction::emitPromotedComplexExpr(const Expr *E,
875875

876876
mlir::Value CIRGenFunction::emitPromotedValue(mlir::Value result,
877877
QualType PromotionType) {
878-
assert(mlir::isa<cir::CIRFPTypeInterface>(
879-
mlir::cast<cir::ComplexType>(result.getType()).getElementTy()) &&
878+
assert(!mlir::cast<cir::ComplexType>(result.getType()).isIntegerComplex() &&
880879
"integral complex will never be promoted");
881880
return builder.createCast(cir::CastKind::float_complex, result,
882881
convertType(PromotionType));
883882
}
884883

885884
mlir::Value CIRGenFunction::emitUnPromotedValue(mlir::Value result,
886885
QualType UnPromotionType) {
887-
assert(mlir::isa<cir::CIRFPTypeInterface>(
888-
mlir::cast<cir::ComplexType>(result.getType()).getElementTy()) &&
886+
assert(!mlir::cast<cir::ComplexType>(result.getType()).isIntegerComplex() &&
889887
"integral complex will never be promoted");
890888
return builder.createCast(cir::CastKind::float_complex, result,
891889
convertType(UnPromotionType));

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -357,12 +357,12 @@ LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
357357
LogicalResult ComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
358358
cir::ComplexType type, mlir::TypedAttr real,
359359
mlir::TypedAttr imag) {
360-
auto elemTy = type.getElementTy();
361-
if (real.getType() != elemTy) {
360+
auto elemType = type.getElementType();
361+
if (real.getType() != elemType) {
362362
emitError() << "type of the real part does not match the complex type";
363363
return failure();
364364
}
365-
if (imag.getType() != elemTy) {
365+
if (imag.getType() != elemType) {
366366
emitError() << "type of the imaginary part does not match the complex type";
367367
return failure();
368368
}

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+29-34
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ LogicalResult cir::CastOp::verify() {
655655
auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
656656
if (!resComplexTy)
657657
return emitOpError() << "requires !cir.complex type for result";
658-
if (srcType != resComplexTy.getElementTy())
658+
if (srcType != resComplexTy.getElementType())
659659
return emitOpError() << "requires source type match result element type";
660660
return success();
661661
}
@@ -665,7 +665,7 @@ LogicalResult cir::CastOp::verify() {
665665
auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
666666
if (!resComplexTy)
667667
return emitOpError() << "requires !cir.complex type for result";
668-
if (srcType != resComplexTy.getElementTy())
668+
if (srcType != resComplexTy.getElementType())
669669
return emitOpError() << "requires source type match result element type";
670670
return success();
671671
}
@@ -675,7 +675,7 @@ LogicalResult cir::CastOp::verify() {
675675
return emitOpError() << "requires !cir.complex type for source";
676676
if (!mlir::isa<cir::CIRFPTypeInterface>(resType))
677677
return emitOpError() << "requires !cir.float type for result";
678-
if (srcComplexTy.getElementTy() != resType)
678+
if (srcComplexTy.getElementType() != resType)
679679
return emitOpError() << "requires source element type match result type";
680680
return success();
681681
}
@@ -685,71 +685,66 @@ LogicalResult cir::CastOp::verify() {
685685
return emitOpError() << "requires !cir.complex type for source";
686686
if (!mlir::isa<cir::IntType>(resType))
687687
return emitOpError() << "requires !cir.int type for result";
688-
if (srcComplexTy.getElementTy() != resType)
688+
if (srcComplexTy.getElementType() != resType)
689689
return emitOpError() << "requires source element type match result type";
690690
return success();
691691
}
692692
case cir::CastKind::float_complex_to_bool: {
693693
auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
694-
if (!srcComplexTy ||
695-
!mlir::isa<cir::CIRFPTypeInterface>(srcComplexTy.getElementTy()))
694+
if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
696695
return emitOpError()
697-
<< "requires !cir.complex<!cir.float> type for source";
696+
<< "requires floating point !cir.complex type for source";
698697
if (!mlir::isa<cir::BoolType>(resType))
699698
return emitOpError() << "requires !cir.bool type for result";
700699
return success();
701700
}
702701
case cir::CastKind::int_complex_to_bool: {
703702
auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
704-
if (!srcComplexTy || !mlir::isa<cir::IntType>(srcComplexTy.getElementTy()))
703+
if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
705704
return emitOpError()
706-
<< "requires !cir.complex<!cir.float> type for source";
705+
<< "requires floating point !cir.complex type for source";
707706
if (!mlir::isa<cir::BoolType>(resType))
708707
return emitOpError() << "requires !cir.bool type for result";
709708
return success();
710709
}
711710
case cir::CastKind::float_complex: {
712711
auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
713-
if (!srcComplexTy ||
714-
!mlir::isa<cir::CIRFPTypeInterface>(srcComplexTy.getElementTy()))
712+
if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
715713
return emitOpError()
716-
<< "requires !cir.complex<!cir.float> type for source";
714+
<< "requires floating point !cir.complex type for source";
717715
auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
718-
if (!resComplexTy ||
719-
!mlir::isa<cir::CIRFPTypeInterface>(resComplexTy.getElementTy()))
716+
if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
720717
return emitOpError()
721-
<< "requires !cir.complex<!cir.float> type for result";
718+
<< "requires floating point !cir.complex type for result";
722719
return success();
723720
}
724721
case cir::CastKind::float_complex_to_int_complex: {
725722
auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
726-
if (!srcComplexTy ||
727-
!mlir::isa<cir::CIRFPTypeInterface>(srcComplexTy.getElementTy()))
723+
if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex())
728724
return emitOpError()
729-
<< "requires !cir.complex<!cir.float> type for source";
725+
<< "requires floating point !cir.complex type for source";
730726
auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
731-
if (!resComplexTy || !mlir::isa<cir::IntType>(resComplexTy.getElementTy()))
732-
return emitOpError() << "requires !cir.complex<!cir.int> type for result";
727+
if (!resComplexTy || !resComplexTy.isIntegerComplex())
728+
return emitOpError() << "requires integer !cir.complex type for result";
733729
return success();
734730
}
735731
case cir::CastKind::int_complex: {
736732
auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
737-
if (!srcComplexTy || !mlir::isa<cir::IntType>(srcComplexTy.getElementTy()))
738-
return emitOpError() << "requires !cir.complex<!cir.int> type for source";
733+
if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
734+
return emitOpError() << "requires integer !cir.complex type for source";
739735
auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
740-
if (!resComplexTy || !mlir::isa<cir::IntType>(resComplexTy.getElementTy()))
741-
return emitOpError() << "requires !cir.complex<!cir.int> type for result";
736+
if (!resComplexTy || !resComplexTy.isIntegerComplex())
737+
return emitOpError() << "requires integer !cir.complex type for result";
742738
return success();
743739
}
744740
case cir::CastKind::int_complex_to_float_complex: {
745741
auto srcComplexTy = mlir::dyn_cast<cir::ComplexType>(srcType);
746-
if (!srcComplexTy || !mlir::isa<cir::IntType>(srcComplexTy.getElementTy()))
747-
return emitOpError() << "requires !cir.complex<!cir.int> type for source";
742+
if (!srcComplexTy || !srcComplexTy.isIntegerComplex())
743+
return emitOpError() << "requires integer !cir.complex type for source";
748744
auto resComplexTy = mlir::dyn_cast<cir::ComplexType>(resType);
749-
if (!resComplexTy ||
750-
!mlir::isa<cir::CIRFPTypeInterface>(resComplexTy.getElementTy()))
745+
if (!resComplexTy || !resComplexTy.isFloatingPointComplex())
751746
return emitOpError()
752-
<< "requires !cir.complex<!cir.float> type for result";
747+
<< "requires floating point !cir.complex type for result";
753748
return success();
754749
}
755750
case cir::CastKind::member_ptr_to_bool: {
@@ -912,7 +907,7 @@ LogicalResult cir::DerivedMethodOp::verify() {
912907
//===----------------------------------------------------------------------===//
913908

914909
LogicalResult cir::ComplexCreateOp::verify() {
915-
if (getType().getElementTy() != getReal().getType()) {
910+
if (getType().getElementType() != getReal().getType()) {
916911
emitOpError()
917912
<< "operand type of cir.complex.create does not match its result type";
918913
return failure();
@@ -945,7 +940,7 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
945940
//===----------------------------------------------------------------------===//
946941

947942
LogicalResult cir::ComplexRealOp::verify() {
948-
if (getType() != getOperand().getType().getElementTy()) {
943+
if (getType() != getOperand().getType().getElementType()) {
949944
emitOpError() << "cir.complex.real result type does not match operand type";
950945
return failure();
951946
}
@@ -960,7 +955,7 @@ OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
960955
}
961956

962957
LogicalResult cir::ComplexImagOp::verify() {
963-
if (getType() != getOperand().getType().getElementTy()) {
958+
if (getType() != getOperand().getType().getElementType()) {
964959
emitOpError() << "cir.complex.imag result type does not match operand type";
965960
return failure();
966961
}
@@ -984,7 +979,7 @@ LogicalResult cir::ComplexRealPtrOp::verify() {
984979
auto operandPointeeTy =
985980
mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
986981

987-
if (resultPointeeTy != operandPointeeTy.getElementTy()) {
982+
if (resultPointeeTy != operandPointeeTy.getElementType()) {
988983
emitOpError()
989984
<< "cir.complex.real_ptr result type does not match operand type";
990985
return failure();
@@ -999,7 +994,7 @@ LogicalResult cir::ComplexImagPtrOp::verify() {
999994
auto operandPointeeTy =
1000995
mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
1001996

1002-
if (resultPointeeTy != operandPointeeTy.getElementTy()) {
997+
if (resultPointeeTy != operandPointeeTy.getElementType()) {
1003998
emitOpError()
1004999
<< "cir.complex.imag_ptr result type does not match operand type";
10051000
return failure();

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

+2-16
Original file line numberDiff line numberDiff line change
@@ -798,18 +798,6 @@ bool cir::isIntOrIntVectorTy(mlir::Type t) {
798798
// ComplexType Definitions
799799
//===----------------------------------------------------------------------===//
800800

801-
mlir::LogicalResult cir::ComplexType::verify(
802-
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
803-
mlir::Type elementTy) {
804-
if (!mlir::isa<cir::IntType, cir::CIRFPTypeInterface>(elementTy)) {
805-
emitError() << "element type of !cir.complex must be either a "
806-
"floating-point type or an integer type";
807-
return failure();
808-
}
809-
810-
return success();
811-
}
812-
813801
llvm::TypeSize
814802
cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
815803
mlir::DataLayoutEntryListRef params) const {
@@ -818,8 +806,7 @@ cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
818806
// as an array type containing exactly two elements of the corresponding
819807
// real type.
820808

821-
auto elementTy = getElementTy();
822-
return dataLayout.getTypeSizeInBits(elementTy) * 2;
809+
return dataLayout.getTypeSizeInBits(getElementType()) * 2;
823810
}
824811

825812
uint64_t
@@ -830,8 +817,7 @@ cir::ComplexType::getABIAlignment(const mlir::DataLayout &dataLayout,
830817
// as an array type containing exactly two elements of the corresponding
831818
// real type.
832819

833-
auto elementTy = getElementTy();
834-
return dataLayout.getTypeABIAlignment(elementTy);
820+
return dataLayout.getTypeABIAlignment(getElementType());
835821
}
836822

837823
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)