@@ -2169,6 +2169,91 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
21692169 << origName << " )), res);\n " ;
21702170
21712171 os << " assert(res);\n " ;
2172+
2173+ // Insert logging function call (optional)
2174+ os << " Function *logFunc = " << origName
2175+ << " .getModule()->getFunction(\" enzymeLogError\" );\n " ;
2176+ os << " if (logFunc) {\n "
2177+ << " std::string moduleName = " << origName
2178+ << " .getModule()->getModuleIdentifier() ;\n "
2179+ << " std::string functionName = " << origName
2180+ << " .getFunction()->getName().str();\n "
2181+ << " std::string blockName = " << origName
2182+ << " .getParent()->getName().str();\n "
2183+ << " int funcIdx = -1, blockIdx = -1, instIdx = -1;\n "
2184+ << " auto funcIt = std::find_if(" << origName
2185+ << " .getModule()->begin(), " << origName
2186+ << " .getModule()->end(),\n "
2187+ " [&](const auto& func) { return &func == "
2188+ << origName
2189+ << " .getFunction(); });\n "
2190+ " if (funcIt != "
2191+ << origName
2192+ << " .getModule()->end()) {\n "
2193+ " funcIdx = "
2194+ " std::distance("
2195+ << origName << " .getModule()->begin(), funcIt);\n "
2196+ << " }\n "
2197+ << " auto blockIt = std::find_if(" << origName
2198+ << " .getFunction()->begin(), " << origName
2199+ << " .getFunction()->end(),\n "
2200+ " [&](const auto& block) { return &block == "
2201+ << origName
2202+ << " .getParent(); });\n "
2203+ " if (blockIt != "
2204+ << origName
2205+ << " .getFunction()->end()) {\n "
2206+ " blockIdx = std::distance("
2207+ << origName << " .getFunction()->begin(), blockIt);\n "
2208+ << " }\n "
2209+ << " auto instIt = std::find_if(" << origName
2210+ << " .getParent()->begin(), " << origName
2211+ << " .getParent()->end(),\n "
2212+ " [&](const auto& curr) { return &curr == &"
2213+ << origName
2214+ << " ; });\n "
2215+ " if (instIt != "
2216+ << origName
2217+ << " .getParent()->end()) {\n "
2218+ " instIdx = std::distance("
2219+ << origName << " .getParent()->begin(), instIt);\n "
2220+ << " }\n "
2221+ << " Value *origValue = "
2222+ " Builder2.CreateFPExt(gutils->getNewFromOriginal(&"
2223+ << origName << " ), Type::getDoubleTy(" << origName
2224+ << " .getContext()));\n "
2225+ << " Value *errValue = Builder2.CreateFPExt(res, "
2226+ " Type::getDoubleTy("
2227+ << origName << " .getContext()));\n "
2228+ << " std::string opcodeName = " << origName
2229+ << " .getOpcodeName();\n "
2230+ << " std::string calleeName = \" <N/A>\" ;\n "
2231+ << " if (auto CI = dyn_cast<CallInst>(&" << origName
2232+ << " )) {\n "
2233+ << " if (Function *fn = CI->getCalledFunction()) {\n "
2234+ << " calleeName = fn->getName();\n "
2235+ << " } else {\n "
2236+ << " calleeName = \" <Unknown>\" ;\n "
2237+ << " }\n "
2238+ << " }\n "
2239+ << " Value *moduleNameValue = "
2240+ " Builder2.CreateGlobalStringPtr(moduleName);\n "
2241+ << " Value *functionNameValue = "
2242+ " Builder2.CreateGlobalStringPtr(functionName + \" (\" +"
2243+ " std::to_string(funcIdx) + \" )\" );\n "
2244+ << " Value *blockNameValue = "
2245+ " Builder2.CreateGlobalStringPtr(blockName + \" (\" +"
2246+ " std::to_string(blockIdx) + \" )\" );\n "
2247+ << " Value *opcodeNameValue = "
2248+ " Builder2.CreateGlobalStringPtr(opcodeName + \" (\" "
2249+ " +std::to_string(instIdx) + \" )\" );\n "
2250+ << " Value *calleeNameValue = "
2251+ " Builder2.CreateGlobalStringPtr(calleeName);\n "
2252+ << " Builder2.CreateCall(logFunc, {origValue, "
2253+ " errValue, opcodeNameValue, calleeNameValue, moduleNameValue, "
2254+ " functionNameValue, blockNameValue});\n "
2255+ << " }\n " ;
2256+
21722257 os << " setDiffe(&" << origName << " , res, Builder2);\n " ;
21732258 os << " break;\n " ;
21742259 os << " }\n " ;
0 commit comments