Skip to content

Commit c0c1070

Browse files
authored
Change compile time or type analysis err into runtime (#1713)
1 parent 55d67c3 commit c0c1070

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

enzyme/Enzyme/InstructionDerivatives.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,15 @@ def : CallPattern<(Op (Op $x, $y):$z),
823823
[ReadNone, NoUnwind]
824824
>;
825825

826+
def : CallPattern<(Op (Op $x, $y):$z),
827+
["cmplx_inv"],
828+
[
829+
(CFDiv (CFNeg (DiffeRet)), (CFMul $z, $z)),
830+
],
831+
(ForwardFromSummedReverse),
832+
[ReadNone, NoUnwind]
833+
>;
834+
826835
def : IntrPattern<(Op $x),
827836
[["sin"]],
828837
[(FMul (DiffeRet), (Intrinsic<"cos"> $x))] ,

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ const llvm::StringMap<llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
157157

158158
{"__fd_sincos_1", Intrinsic::not_intrinsic},
159159
{"sincospi", Intrinsic::not_intrinsic},
160+
{"cmplx_inv", Intrinsic::not_intrinsic},
160161

161162
// bessel functions
162163
{"j0", Intrinsic::not_intrinsic},
@@ -2937,7 +2938,32 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T,
29372938
// If ^ against 0b10000000000, the result is a float
29382939
bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl);
29392940
if (validXor) {
2940-
((i == 0) ? RHS : LHS) |= TypeTree(FT).Only(-1, nullptr);
2941+
bool Legal = true;
2942+
((i == 0) ? RHS : LHS)
2943+
.checkedOrIn(TypeTree(FT).Only(-1, nullptr),
2944+
/*pointerintsame*/ false, Legal);
2945+
2946+
if (!Legal) {
2947+
std::string str;
2948+
raw_string_ostream ss(str);
2949+
if (!CustomErrorHandler) {
2950+
llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
2951+
llvm::errs() << *fntypeinfo.Function << "\n";
2952+
dump(ss);
2953+
}
2954+
ss << "Illegal updateBinop (xor up) Analysis " << *origin << "\n";
2955+
ss << " (i=" << i << ") " << (i == 0 ? "RHS" : "LHS") << " "
2956+
<< ((i == 0) ? RHS : LHS).str() << " FT from ret: " << *FT
2957+
<< "\n";
2958+
if (CustomErrorHandler) {
2959+
CustomErrorHandler(str.c_str(), wrap(origin),
2960+
ErrorType::IllegalTypeAnalysis, (void *)this,
2961+
wrap(origin), nullptr);
2962+
}
2963+
EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(),
2964+
origin, ss.str());
2965+
report_fatal_error("Performed illegal updateAnalysis");
2966+
}
29412967
}
29422968
}
29432969
break;

0 commit comments

Comments
 (0)