Skip to content

Commit 75363f7

Browse files
sbrantqwsmoses
andauthored
Logging for error estimate (#1859)
* Add logging func call * Add test * Add original value * func & bb name (requires -fno-discard-value-names) * improve * indices * improve * use std::distance instead * fix private method call * improve * fix format * add counter test * Add test eq mechanism --------- Co-authored-by: William S. Moses <[email protected]>
1 parent db5d616 commit 75363f7

File tree

3 files changed

+107
-1
lines changed

3 files changed

+107
-1
lines changed

enzyme/test/Integration/ForwardError/binops.c

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,24 @@ double fabs(double);
1111

1212
extern double __enzyme_error_estimate(void *, ...);
1313

14+
int errorLogCount = 0;
15+
16+
void enzymeLogError(double res, double err, const char *opcodeName,
17+
const char *calleeName, const char *moduleName,
18+
const char *functionName, const char *blockName) {
19+
++errorLogCount;
20+
printf("Res = %e, Error = %e, Op = %s, Callee = %s, Module = %s, Function = "
21+
"%s, BasicBlock = %s\n",
22+
res, err, opcodeName, calleeName, moduleName, functionName, blockName);
23+
}
24+
1425
// An example from https://dl.acm.org/doi/10.1145/3371128
1526
double fun(double x) {
1627
double v1 = cos(x);
1728
double v2 = 1 - v1;
1829
double v3 = x * x;
1930
double v4 = v2 / v3;
20-
double v5 = sin(v4);
31+
double v5 = sin(v4); // Inactive -- logger is not invoked.
2132

2233
printf("v1 = %.18e, v2 = %.18e, v3 = %.18e, v4 = %.18e, v5 = %.18e\n", v1, v2,
2334
v3, v4, v5);
@@ -31,4 +42,5 @@ int main() {
3142
printf("res = %.18e, abs error = %.18e, rel error = %.18e\n", res, error,
3243
fabs(error / res));
3344
APPROX_EQ(error, 2.2222222222e-2, 1e-4);
45+
TEST_EQ(errorLogCount, 4);
3446
}

enzyme/test/Integration/test_utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,12 @@ static inline bool approx_fp_equality_double(double f1, double f2, double thresh
5353
abort(); \
5454
} \
5555
};
56+
57+
#define TEST_EQ(LHS, RHS) \
58+
{ \
59+
if ((LHS) != (RHS)) {\
60+
fprintf(stderr, "Assertion Failed: [%s = %d] != [%s = %d] at %s:%d (%s)\n", #LHS, (int)(LHS), #RHS, (int)(RHS), \
61+
__FILE__, __LINE__, __PRETTY_FUNCTION__); \
62+
abort(); \
63+
} \
64+
};

enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2169,6 +2169,91 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
21692169
<< origName << ")), res);\n";
21702170

21712171
os << " assert(res);\n";
2172+
2173+
// Insert logging function call (optional)
2174+
os << " Function *logFunc = " << origName
2175+
<< ".getModule()->getFunction(\"enzymeLogError\");\n";
2176+
os << " if (logFunc) {\n"
2177+
<< " std::string moduleName = " << origName
2178+
<< ".getModule()->getModuleIdentifier() ;\n"
2179+
<< " std::string functionName = " << origName
2180+
<< ".getFunction()->getName().str();\n"
2181+
<< " std::string blockName = " << origName
2182+
<< ".getParent()->getName().str();\n"
2183+
<< " int funcIdx = -1, blockIdx = -1, instIdx = -1;\n"
2184+
<< " auto funcIt = std::find_if(" << origName
2185+
<< ".getModule()->begin(), " << origName
2186+
<< ".getModule()->end(),\n"
2187+
" [&](const auto& func) { return &func == "
2188+
<< origName
2189+
<< ".getFunction(); });\n"
2190+
" if (funcIt != "
2191+
<< origName
2192+
<< ".getModule()->end()) {\n"
2193+
" funcIdx = "
2194+
"std::distance("
2195+
<< origName << ".getModule()->begin(), funcIt);\n"
2196+
<< " }\n"
2197+
<< " auto blockIt = std::find_if(" << origName
2198+
<< ".getFunction()->begin(), " << origName
2199+
<< ".getFunction()->end(),\n"
2200+
" [&](const auto& block) { return &block == "
2201+
<< origName
2202+
<< ".getParent(); });\n"
2203+
" if (blockIt != "
2204+
<< origName
2205+
<< ".getFunction()->end()) {\n"
2206+
" blockIdx = std::distance("
2207+
<< origName << ".getFunction()->begin(), blockIt);\n"
2208+
<< " }\n"
2209+
<< " auto instIt = std::find_if(" << origName
2210+
<< ".getParent()->begin(), " << origName
2211+
<< ".getParent()->end(),\n"
2212+
" [&](const auto& curr) { return &curr == &"
2213+
<< origName
2214+
<< "; });\n"
2215+
" if (instIt != "
2216+
<< origName
2217+
<< ".getParent()->end()) {\n"
2218+
" instIdx = std::distance("
2219+
<< origName << ".getParent()->begin(), instIt);\n"
2220+
<< " }\n"
2221+
<< " Value *origValue = "
2222+
"Builder2.CreateFPExt(gutils->getNewFromOriginal(&"
2223+
<< origName << "), Type::getDoubleTy(" << origName
2224+
<< ".getContext()));\n"
2225+
<< " Value *errValue = Builder2.CreateFPExt(res, "
2226+
"Type::getDoubleTy("
2227+
<< origName << ".getContext()));\n"
2228+
<< " std::string opcodeName = " << origName
2229+
<< ".getOpcodeName();\n"
2230+
<< " std::string calleeName = \"<N/A>\";\n"
2231+
<< " if (auto CI = dyn_cast<CallInst>(&" << origName
2232+
<< ")) {\n"
2233+
<< " if (Function *fn = CI->getCalledFunction()) {\n"
2234+
<< " calleeName = fn->getName();\n"
2235+
<< " } else {\n"
2236+
<< " calleeName = \"<Unknown>\";\n"
2237+
<< " }\n"
2238+
<< " }\n"
2239+
<< " Value *moduleNameValue = "
2240+
"Builder2.CreateGlobalStringPtr(moduleName);\n"
2241+
<< " Value *functionNameValue = "
2242+
"Builder2.CreateGlobalStringPtr(functionName + \" (\" +"
2243+
"std::to_string(funcIdx) + \")\");\n"
2244+
<< " Value *blockNameValue = "
2245+
"Builder2.CreateGlobalStringPtr(blockName + \" (\" +"
2246+
"std::to_string(blockIdx) + \")\");\n"
2247+
<< " Value *opcodeNameValue = "
2248+
"Builder2.CreateGlobalStringPtr(opcodeName + \" (\" "
2249+
"+std::to_string(instIdx) + \")\");\n"
2250+
<< " Value *calleeNameValue = "
2251+
"Builder2.CreateGlobalStringPtr(calleeName);\n"
2252+
<< " Builder2.CreateCall(logFunc, {origValue, "
2253+
"errValue, opcodeNameValue, calleeNameValue, moduleNameValue, "
2254+
"functionNameValue, blockNameValue});\n"
2255+
<< " }\n";
2256+
21722257
os << " setDiffe(&" << origName << ", res, Builder2);\n";
21732258
os << " break;\n";
21742259
os << " }\n";

0 commit comments

Comments
 (0)