Skip to content

Commit b5b81f8

Browse files
authored
[TypeAnalysis] improve memtransfer error handler (#1588)
* [TypeAnalysis] improve memtransfer error handler * also add better binop error
1 parent 48cbcca commit b5b81f8

File tree

3 files changed

+158
-30
lines changed

3 files changed

+158
-30
lines changed

enzyme/Enzyme/TypeAnalysis/ConcreteType.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ class ConcreteType {
292292
/// Set this to the logical `binop` of itself and RHS, using the Binop Op,
293293
/// returning true if this was changed.
294294
/// This function will error on an invalid type combination
295-
bool binopIn(const ConcreteType RHS, llvm::BinaryOperator::BinaryOps Op) {
295+
bool binopIn(bool &Legal, const ConcreteType RHS,
296+
llvm::BinaryOperator::BinaryOps Op) {
296297
bool Changed = false;
297298
using namespace llvm;
298299

@@ -366,7 +367,8 @@ class ConcreteType {
366367
// No change since we retain data from LHS
367368
break;
368369
default:
369-
llvm_unreachable("unknown binary operator");
370+
Legal = false;
371+
return Changed;
370372
}
371373
return Changed;
372374
}
@@ -404,10 +406,9 @@ class ConcreteType {
404406
case BinaryOperator::Shl:
405407
case BinaryOperator::AShr:
406408
case BinaryOperator::LShr:
407-
llvm_unreachable("illegal pointer/pointer operation");
408-
break;
409409
default:
410-
llvm_unreachable("unknown binary operator");
410+
Legal = false;
411+
return Changed;
411412
}
412413
return Changed;
413414
}
@@ -468,7 +469,8 @@ class ConcreteType {
468469
case BinaryOperator::URem:
469470
case BinaryOperator::SRem:
470471
if (RHS.SubTypeEnum == BaseType::Pointer) {
471-
llvm_unreachable("cannot divide integer by pointer");
472+
Legal = false;
473+
return Changed;
472474
} else if (SubTypeEnum != BaseType::Unknown) {
473475
SubTypeEnum = BaseType::Unknown;
474476
Changed = true;
@@ -486,14 +488,14 @@ class ConcreteType {
486488
}
487489
break;
488490
default:
489-
llvm_unreachable("unknown binary operator");
491+
Legal = false;
492+
return Changed;
490493
}
491494
return Changed;
492495
}
493496

494-
llvm::errs() << "self: " << str() << " RHS: " << RHS.str() << " Op: " << Op
495-
<< "\n";
496-
llvm_unreachable("Unknown ConcreteType::binopIn");
497+
Legal = false;
498+
return Changed;
497499
}
498500

499501
/// Compare concrete types for use in map's

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 132 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,12 +2071,57 @@ void TypeAnalyzer::visitPHINode(PHINode &phi) {
20712071
if (BO->getOperand(0) == &phi) {
20722072
set = true;
20732073
PhiTypes = otherData;
2074-
PhiTypes.binopIn(getAnalysis(BO->getOperand(1)), BO->getOpcode());
2074+
bool Legal = true;
2075+
PhiTypes.binopIn(Legal, getAnalysis(BO->getOperand(1)),
2076+
BO->getOpcode());
2077+
if (!Legal) {
2078+
std::string str;
2079+
raw_string_ostream ss(str);
2080+
if (!CustomErrorHandler) {
2081+
llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
2082+
llvm::errs() << *fntypeinfo.Function << "\n";
2083+
dump(ss);
2084+
}
2085+
ss << "Illegal updateBinop Analysis " << *BO << "\n";
2086+
ss << "Illegal binopIn(0): " << *BO
2087+
<< " lhs: " << PhiTypes.str()
2088+
<< " rhs: " << getAnalysis(BO->getOperand(0)).str() << "\n";
2089+
if (CustomErrorHandler) {
2090+
CustomErrorHandler(str.c_str(), wrap(BO),
2091+
ErrorType::IllegalTypeAnalysis,
2092+
(void *)this, wrap(BO), nullptr);
2093+
}
2094+
EmitFailure("IllegalUpdateAnalysis", BO->getDebugLoc(), BO,
2095+
ss.str());
2096+
report_fatal_error("Performed illegal updateAnalysis");
2097+
}
20752098
break;
20762099
} else if (BO->getOperand(1) == &phi) {
20772100
set = true;
20782101
PhiTypes = getAnalysis(BO->getOperand(0));
2079-
PhiTypes.binopIn(otherData, BO->getOpcode());
2102+
bool Legal = true;
2103+
PhiTypes.binopIn(Legal, otherData, BO->getOpcode());
2104+
if (!Legal) {
2105+
std::string str;
2106+
raw_string_ostream ss(str);
2107+
if (!CustomErrorHandler) {
2108+
llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
2109+
llvm::errs() << *fntypeinfo.Function << "\n";
2110+
dump(ss);
2111+
}
2112+
ss << "Illegal updateBinop Analysis " << *BO << "\n";
2113+
ss << "Illegal binopIn(1): " << *BO
2114+
<< " lhs: " << PhiTypes.str() << " rhs: " << otherData.str()
2115+
<< "\n";
2116+
if (CustomErrorHandler) {
2117+
CustomErrorHandler(str.c_str(), wrap(BO),
2118+
ErrorType::IllegalTypeAnalysis,
2119+
(void *)this, wrap(BO), nullptr);
2120+
}
2121+
EmitFailure("IllegalUpdateAnalysis", BO->getDebugLoc(), BO,
2122+
ss.str());
2123+
report_fatal_error("Performed illegal updateAnalysis");
2124+
}
20802125
break;
20812126
}
20822127
} else if (BO->getOpcode() == BinaryOperator::Sub) {
@@ -2124,7 +2169,27 @@ void TypeAnalyzer::visitPHINode(PHINode &phi) {
21242169
TypeTree vd2 = isa<Constant>(bo->getOperand(1))
21252170
? getAnalysis(bo->getOperand(1)).Data0()
21262171
: PhiTypes.Data0();
2127-
vd1.binopIn(vd2, bo->getOpcode());
2172+
bool Legal = true;
2173+
vd1.binopIn(Legal, vd2, bo->getOpcode());
2174+
if (!Legal) {
2175+
std::string str;
2176+
raw_string_ostream ss(str);
2177+
if (!CustomErrorHandler) {
2178+
llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
2179+
llvm::errs() << *fntypeinfo.Function << "\n";
2180+
dump(ss);
2181+
}
2182+
ss << "Illegal updateBinop Analysis " << *bo << "\n";
2183+
ss << "Illegal binopIn(consts): " << *bo << " lhs: " << vd1.str()
2184+
<< " rhs: " << vd2.str() << "\n";
2185+
if (CustomErrorHandler) {
2186+
CustomErrorHandler(str.c_str(), wrap(bo),
2187+
ErrorType::IllegalTypeAnalysis, (void *)this,
2188+
wrap(bo), nullptr);
2189+
}
2190+
EmitFailure("IllegalUpdateAnalysis", bo->getDebugLoc(), bo, ss.str());
2191+
report_fatal_error("Performed illegal updateAnalysis");
2192+
}
21282193
PhiTypes &= vd1.Only(bo->getType()->isIntegerTy() ? -1 : 0, &phi);
21292194
}
21302195

@@ -2999,8 +3064,28 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T,
29993064

30003065
if (direction & DOWN) {
30013066
TypeTree Result = AnalysisLHS;
3002-
Result.binopIn(AnalysisRHS, Opcode);
3003-
3067+
bool Legal = true;
3068+
Result.binopIn(Legal, AnalysisRHS, Opcode);
3069+
if (!Legal) {
3070+
std::string str;
3071+
raw_string_ostream ss(str);
3072+
if (!CustomErrorHandler) {
3073+
llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
3074+
llvm::errs() << *fntypeinfo.Function << "\n";
3075+
dump(ss);
3076+
}
3077+
ss << "Illegal updateBinop Analysis " << *origin << "\n";
3078+
ss << "Illegal binopIn(down): " << Opcode << " lhs: " << Result.str()
3079+
<< " rhs: " << AnalysisRHS.str() << "\n";
3080+
if (CustomErrorHandler) {
3081+
CustomErrorHandler(str.c_str(), wrap(origin),
3082+
ErrorType::IllegalTypeAnalysis, (void *)this,
3083+
wrap(origin), nullptr);
3084+
}
3085+
EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(), origin,
3086+
ss.str());
3087+
report_fatal_error("Performed illegal updateAnalysis");
3088+
}
30043089
if (Opcode == BinaryOperator::And) {
30053090
for (int i = 0; i < 2; ++i) {
30063091
if (Args[i])
@@ -3254,16 +3339,27 @@ void TypeAnalyzer::visitMemTransferCommon(llvm::CallBase &MTI) {
32543339
bool Legal = true;
32553340
res.checkedOrIn(res2, /*PointerIntSame*/ false, Legal);
32563341
if (!Legal) {
3257-
dump();
3258-
llvm::errs() << MTI << "\n";
3259-
llvm::errs() << "Illegal orIn: " << res.str() << " right: " << res2.str()
3260-
<< "\n";
3261-
llvm::errs() << *MTI.getArgOperand(0) << " "
3262-
<< getAnalysis(MTI.getArgOperand(0)).str() << "\n";
3263-
llvm::errs() << *MTI.getArgOperand(1) << " "
3264-
<< getAnalysis(MTI.getArgOperand(1)).str() << "\n";
3265-
assert(0 && "Performed illegal visitMemTransferInst::orIn");
3266-
llvm_unreachable("Performed illegal visitMemTransferInst::orIn");
3342+
std::string str;
3343+
raw_string_ostream ss(str);
3344+
if (!CustomErrorHandler) {
3345+
llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
3346+
llvm::errs() << *fntypeinfo.Function << "\n";
3347+
dump(ss);
3348+
}
3349+
ss << "Illegal updateMemTransfer Analysis " << MTI << "\n";
3350+
ss << "Illegal orIn: " << res.str() << " right: " << res2.str() << "\n";
3351+
ss << *MTI.getArgOperand(0) << " "
3352+
<< getAnalysis(MTI.getArgOperand(0)).str() << "\n";
3353+
ss << *MTI.getArgOperand(1) << " "
3354+
<< getAnalysis(MTI.getArgOperand(1)).str() << "\n";
3355+
3356+
if (CustomErrorHandler) {
3357+
CustomErrorHandler(str.c_str(), wrap(&MTI),
3358+
ErrorType::IllegalTypeAnalysis, (void *)this,
3359+
wrap(&MTI), nullptr);
3360+
}
3361+
EmitFailure("IllegalUpdateAnalysis", MTI.getDebugLoc(), &MTI, ss.str());
3362+
report_fatal_error("Performed illegal updateAnalysis");
32673363
}
32683364
res.insert({}, BaseType::Pointer);
32693365
res = res.Only(-1, &MTI);
@@ -3810,8 +3906,27 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
38103906
updateAnalysis(I.getOperand(1), analysis.Only(-1, &I), &I);
38113907

38123908
TypeTree vd = getAnalysis(I.getOperand(0)).Data0();
3813-
vd.binopIn(getAnalysis(I.getOperand(1)).Data0(), opcode);
3814-
3909+
bool Legal = true;
3910+
vd.binopIn(Legal, getAnalysis(I.getOperand(1)).Data0(), opcode);
3911+
if (!Legal) {
3912+
std::string str;
3913+
raw_string_ostream ss(str);
3914+
if (!CustomErrorHandler) {
3915+
llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
3916+
llvm::errs() << *fntypeinfo.Function << "\n";
3917+
dump(ss);
3918+
}
3919+
ss << "Illegal updateBinopIntr Analysis " << I << "\n";
3920+
ss << "Illegal binopIn(intr): " << I << " lhs: " << vd.str()
3921+
<< " rhs: " << getAnalysis(I.getOperand(1)).str() << "\n";
3922+
if (CustomErrorHandler) {
3923+
CustomErrorHandler(str.c_str(), wrap(&I),
3924+
ErrorType::IllegalTypeAnalysis, (void *)this,
3925+
wrap(&I), nullptr);
3926+
}
3927+
EmitFailure("IllegalUpdateAnalysis", I.getDebugLoc(), &I, ss.str());
3928+
report_fatal_error("Performed illegal updateAnalysis");
3929+
}
38153930
auto &dl = I.getParent()->getParent()->getParent()->getDataLayout();
38163931
int sz = (dl.getTypeSizeInBits(I.getOperand(0)->getType()) + 7) / 8;
38173932
TypeTree overall = vd.Only(-1, &I).ShiftIndices(dl, 0, sz, 0);

enzyme/Enzyme/TypeAnalysis/TypeTree.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,8 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
11121112
/// Set this to the logical `binop` of itself and RHS, using the Binop Op,
11131113
/// returning true if this was changed.
11141114
/// This function will error on an invalid type combination
1115-
bool binopIn(const TypeTree &RHS, llvm::BinaryOperator::BinaryOps Op) {
1115+
bool binopIn(bool &Legal, const TypeTree &RHS,
1116+
llvm::BinaryOperator::BinaryOps Op) {
11161117
bool changed = false;
11171118

11181119
for (auto &pair : llvm::make_early_inc_range(mapping)) {
@@ -1134,7 +1135,12 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
11341135
RightCT = found->second;
11351136
}
11361137

1137-
changed |= CT.binopIn(RightCT, Op);
1138+
bool SubLegal = true;
1139+
changed |= CT.binopIn(SubLegal, RightCT, Op);
1140+
if (!SubLegal) {
1141+
Legal = false;
1142+
return changed;
1143+
}
11381144
if (CT == BaseType::Unknown) {
11391145
mapping.erase(pair.first);
11401146
} else {
@@ -1154,7 +1160,12 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
11541160

11551161
if (mapping.find(pair.first) == RHS.mapping.end()) {
11561162
ConcreteType CT = BaseType::Unknown;
1157-
changed |= CT.binopIn(pair.second, Op);
1163+
bool SubLegal = true;
1164+
changed |= CT.binopIn(SubLegal, pair.second, Op);
1165+
if (!SubLegal) {
1166+
Legal = false;
1167+
return changed;
1168+
}
11581169
if (CT != BaseType::Unknown) {
11591170
mapping.insert(std::make_pair(pair.first, CT));
11601171
}

0 commit comments

Comments
 (0)