From 0a3d62bf60eadabfd9405fd4b865695911fedbac Mon Sep 17 00:00:00 2001 From: Giovanni Date: Sat, 7 Dec 2024 13:43:13 +0100 Subject: [PATCH 1/3] [df] JIT graph creation functions once, call them many times --- .../dataframe/inc/ROOT/RDF/InterfaceUtils.hxx | 34 +--- .../dataframe/inc/ROOT/RDF/RInterfaceBase.hxx | 9 +- tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx | 26 +++ tree/dataframe/src/RDFInterfaceUtils.cxx | 158 ++++++++---------- tree/dataframe/src/RLoopManager.cxx | 74 +++++++- 5 files changed, 180 insertions(+), 121 deletions(-) diff --git a/tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx b/tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx index 104d41c9a12cc..742dc48c952a0 100644 --- a/tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx +++ b/tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx @@ -342,10 +342,9 @@ BookVariationJit(const std::vector &colNames, std::string_view vari RDataSource *ds, const RColumnRegister &colRegister, const ColumnNames_t &branches, std::shared_ptr *upcastNodeOnHeap, bool isSingleColumn); -std::string JitBuildAction(const ColumnNames_t &bl, std::shared_ptr *prevNode, - const std::type_info &art, const std::type_info &at, void *rOnHeap, TTree *tree, +std::string JitBuildAction(const ColumnNames_t &bl, const std::type_info &art, const std::type_info &at, TTree *tree, const unsigned int nSlots, const RColumnRegister &colRegister, RDataSource *ds, - std::weak_ptr *jittedActionOnHeap, const bool vector2RVec = true); + const bool vector2RVec = true); // Allocate a weak_ptr on the heap, return a pointer to it. The user is responsible for deleting this weak_ptr. // This function is meant to be used by RInterface's methods that book code for jitting. @@ -420,7 +419,7 @@ void AddDSColumns(const std::vector &requiredCols, RLoopManager &lm // this function is meant to be called by the jitted code generated by BookFilterJit template -void JitFilterHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::string_view name, +void JitFilterHelper(F &&f, const ColumnNames_t &cols, std::string_view name, std::weak_ptr *wkJittedFilter, std::shared_ptr *prevNodeOnHeap, RColumnRegister *colRegister) noexcept { @@ -433,9 +432,6 @@ void JitFilterHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str return; } - const ColumnNames_t cols(colsPtr, colsPtr + colsSize); - delete[] colsPtr; - const auto jittedFilter = wkJittedFilter->lock(); // mock Filter logic -- validity checks and Define-ition of RDataSource columns @@ -485,7 +481,7 @@ auto MakeDefineNode(DefineTypes::RDefinePerSampleTag, std::string_view name, std // This function is meant to be called by jitted code right before starting the event loop. // If colsPtr is null, build a RDefinePerSample (it has no input columns), otherwise a RDefine. template -void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::string_view name, RLoopManager *lm, +void JitDefineHelper(F &&f, const ColumnNames_t &cols, std::string_view name, RLoopManager *lm, std::weak_ptr *wkJittedDefine, RColumnRegister *colRegister, std::shared_ptr *prevNodeOnHeap) noexcept { @@ -494,7 +490,6 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str delete wkJittedDefine; delete colRegister; delete prevNodeOnHeap; - delete[] colsPtr; }; if (wkJittedDefine->expired()) { @@ -504,8 +499,6 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str return; } - const ColumnNames_t cols(colsPtr, colsPtr + colsSize); - auto jittedDefine = wkJittedDefine->lock(); using Callable_t = std::decay_t; @@ -527,18 +520,14 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str } template -void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const char **variedCols, - std::size_t variedColsSize, const char **variationTags, std::size_t variationTagsSize, - std::string_view variationName, RLoopManager *lm, - std::weak_ptr *wkJittedVariation, RColumnRegister *colRegister, - std::shared_ptr *prevNodeOnHeap) noexcept +void JitVariationHelper(F &&f, const ColumnNames_t &inputColNames, const ColumnNames_t &variedColNames, + const char **variationTags, std::size_t variationTagsSize, std::string_view variationName, + RLoopManager *lm, std::weak_ptr *wkJittedVariation, + RColumnRegister *colRegister, std::shared_ptr *prevNodeOnHeap) noexcept { // a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code auto doDeletes = [&] { - delete[] colsPtr; - delete[] variedCols; delete[] variationTags; - delete wkJittedVariation; delete colRegister; delete prevNodeOnHeap; @@ -551,8 +540,6 @@ void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const return; } - const ColumnNames_t inputColNames(colsPtr, colsPtr + colsSize); - std::vector variedColNames(variedCols, variedCols + variedColsSize); std::vector tags(variationTags, variationTags + variationTagsSize); auto jittedVariation = wkJittedVariation->lock(); @@ -575,13 +562,12 @@ void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const /// Convenience function invoked by jitted code to build action nodes at runtime template -void CallBuildAction(std::shared_ptr *prevNodeOnHeap, const char **colsPtr, std::size_t colsSize, +void CallBuildAction(std::shared_ptr *prevNodeOnHeap, const ColumnNames_t &cols, const unsigned int nSlots, std::shared_ptr *helperArgOnHeap, std::weak_ptr *wkJittedActionOnHeap, RColumnRegister *colRegister) noexcept { // a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code auto doDeletes = [&] { - delete[] colsPtr; delete helperArgOnHeap; delete wkJittedActionOnHeap; // colRegister must be deleted before prevNodeOnHeap because their dtor needs the RLoopManager to be alive @@ -597,8 +583,6 @@ void CallBuildAction(std::shared_ptr *prevNodeOnHeap, const char * return; } - const ColumnNames_t cols(colsPtr, colsPtr + colsSize); - auto jittedActionOnHeap = wkJittedActionOnHeap->lock(); // if we are here it means we are jitting, if we are jitting the loop manager must be alive diff --git a/tree/dataframe/inc/ROOT/RDF/RInterfaceBase.hxx b/tree/dataframe/inc/ROOT/RDF/RInterfaceBase.hxx index 74a0cf7556cdf..4346e5bfc7bab 100644 --- a/tree/dataframe/inc/ROOT/RDF/RInterfaceBase.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RInterfaceBase.hxx @@ -189,10 +189,11 @@ protected: fColRegister, proxiedPtr->GetVariations()); auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction); - auto toJit = RDFInternal::JitBuildAction(validColumnNames, upcastNodeOnHeap, typeid(HelperArgType), - typeid(ActionTag), helperArgOnHeap, tree, nSlots, fColRegister, - fDataSource, jittedActionOnHeap, vector2RVec); - fLoopManager->ToJitExec(toJit); + auto definesCopy = new RDFInternal::RColumnRegister(fColRegister); // deleted in jitted call + auto funcBody = RDFInternal::JitBuildAction(validColumnNames, typeid(HelperArgType), typeid(ActionTag), tree, + nSlots, fColRegister, fDataSource, vector2RVec); + fLoopManager->RegisterJitHelperCall(funcBody, upcastNodeOnHeap, definesCopy, validColumnNames, jittedActionOnHeap, + helperArgOnHeap); return MakeResultPtr(r, *fLoopManager, std::move(jittedAction)); } diff --git a/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx b/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx index 31967a245cde2..314ec5fb55410 100644 --- a/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx @@ -50,6 +50,7 @@ class RActionBase; class RVariationBase; class RDefinesWithReaders; class RVariationsWithReaders; +class RColumnRegister; namespace GraphDrawing { class GraphCreatorHelper; @@ -192,6 +193,27 @@ class RLoopManager : public RNodeBase { std::set>> fUniqueVariationsWithReaders; + // deferred function calls to Jitted functions + struct DeferredJitCall { + std::string functionId; + std::shared_ptr *prevNodeOnHeap; + ROOT::Internal::RDF::RColumnRegister *colRegister; + std::vector colNames; + void *wkJittedNode, *argument; + DeferredJitCall(const std::string &id, std::shared_ptr *prevNode, + ROOT::Internal::RDF::RColumnRegister *cols, const std::vector &colnames, + void *wkNodePtr, void *arg) + : functionId(id), + prevNodeOnHeap(prevNode), + colRegister(cols), + colNames(colnames), + wkJittedNode(wkNodePtr), + argument(arg) + { + } + }; + std::vector fJitHelperCalls; + public: RLoopManager(TTree *tree, const ColumnNames_t &defaultBranches); RLoopManager(std::unique_ptr tree, const ColumnNames_t &defaultBranches); @@ -209,6 +231,7 @@ public: void JitDeclarations(); void Jit(); + void RunDeferredCalls(); RLoopManager *GetLoopManagerUnchecked() final { return this; } void Run(bool jit = true); const ColumnNames_t &GetDefaultColumnNames() const; @@ -235,6 +258,9 @@ public: void IncrChildrenCount() final { ++fNChildren; } void StopProcessing() final { ++fNStopsReceived; } void ToJitExec(const std::string &) const; + void RegisterJitHelperCall(const std::string &funcBody, std::shared_ptr *prevNodeOnHeap, + ROOT::Internal::RDF::RColumnRegister *colRegister, + const std::vector &colNames, void *wkJittedPtr, void *argument = nullptr); void RegisterCallback(ULong64_t everyNEvents, std::function &&f); unsigned int GetNRuns() const { return fNRuns; } bool HasDataSourceColumnReaders(const std::string &col, const std::type_info &ti) const; diff --git a/tree/dataframe/src/RDFInterfaceUtils.cxx b/tree/dataframe/src/RDFInterfaceUtils.cxx index 773288f465bf6..aa3942366e369 100644 --- a/tree/dataframe/src/RDFInterfaceUtils.cxx +++ b/tree/dataframe/src/RDFInterfaceUtils.cxx @@ -17,6 +17,7 @@ #include #include #include +#include "ROOT/RLogger.hxx" #include #include #include @@ -670,35 +671,29 @@ BookFilterJit(std::shared_ptr *prevNodeOnHeap, std::string // definesOnHeap is deleted by the jitted call to JitFilterHelper ROOT::Internal::RDF::RColumnRegister *definesOnHeap = new ROOT::Internal::RDF::RColumnRegister(colRegister); - const auto definesOnHeapAddr = PrettyPrintAddr(definesOnHeap); - const auto prevNodeAddr = PrettyPrintAddr(prevNodeOnHeap); const auto jittedFilter = std::make_shared( (*prevNodeOnHeap)->GetLoopManagerUnchecked(), name, Union(colRegister.GetVariationDeps(parsedExpr.fUsedCols), (*prevNodeOnHeap)->GetVariations())); // Produce code snippet that creates the filter and registers it with the corresponding RJittedFilter - // Windows requires std::hex << std::showbase << (size_t)pointer to produce notation "0x1234" - std::stringstream filterInvocation; - filterInvocation << "ROOT::Internal::RDF::JitFilterHelper(" << funcName << ", new const char*[" - << parsedExpr.fUsedCols.size() << "]{"; - for (const auto &col : parsedExpr.fUsedCols) - filterInvocation << "\"" << col << "\", "; - if (!parsedExpr.fUsedCols.empty()) - filterInvocation.seekp(-2, filterInvocation.cur); // remove the last ", // lifetime of pointees: // - jittedFilter: heap-allocated weak_ptr to the actual jittedFilter that will be deleted by JitFilterHelper // - prevNodeOnHeap: heap-allocated shared_ptr to the actual previous node that will be deleted by JitFilterHelper // - definesOnHeap: heap-allocated, will be deleted by JitFilterHelper - filterInvocation << "}, " << parsedExpr.fUsedCols.size() << ", \"" << name << "\", " - << "reinterpret_cast*>(" - << PrettyPrintAddr(MakeWeakOnHeap(jittedFilter)) << "), " - << "reinterpret_cast*>(" << prevNodeAddr << ")," - << "reinterpret_cast(" << definesOnHeapAddr << ")" - << ");\n"; - + std::stringstream filterInvocation; + filterInvocation << "(ROOT::Detail::RDF::RLoopManager *lm, " + << "std::shared_ptr *prevNodeOnHeap," + << "ROOT::Internal::RDF::RColumnRegister* colRegister, " + << "const std::vector & colNames, " + << "void *wkJittedFilter, void *) {\n"; + filterInvocation << " ROOT::Internal::RDF::JitFilterHelper(" << funcName << ", " + << " colNames, \"" << name << "\", " + << "reinterpret_cast*>(wkJittedFilter)," + << "prevNodeOnHeap, colRegister);\n}\n"; auto lm = jittedFilter->GetLoopManagerUnchecked(); - lm->ToJitExec(filterInvocation.str()); + lm->RegisterJitHelperCall(filterInvocation.str(), prevNodeOnHeap, definesOnHeap, parsedExpr.fUsedCols, + MakeWeakOnHeap(jittedFilter)); return jittedFilter; } @@ -719,30 +714,28 @@ std::shared_ptr BookDefineJit(std::string_view name, std::string_ const auto type = RetTypeOfFunc(funcName); auto definesCopy = new RColumnRegister(colRegister); - auto definesAddr = PrettyPrintAddr(definesCopy); auto jittedDefine = std::make_shared(name, type, lm, colRegister, parsedExpr.fUsedCols); - std::stringstream defineInvocation; - defineInvocation << "ROOT::Internal::RDF::JitDefineHelper(" << funcName - << ", new const char*[" << parsedExpr.fUsedCols.size() << "]{"; - for (const auto &col : parsedExpr.fUsedCols) { - defineInvocation << "\"" << col << "\", "; - } - if (!parsedExpr.fUsedCols.empty()) - defineInvocation.seekp(-2, defineInvocation.cur); // remove the last ", // lifetime of pointees: // - lm is the loop manager, and if that goes out of scope jitting does not happen at all (i.e. will always be valid) // - jittedDefine: heap-allocated weak_ptr that will be deleted by JitDefineHelper after usage // - definesAddr: heap-allocated, will be deleted by JitDefineHelper after usage - defineInvocation << "}, " << parsedExpr.fUsedCols.size() << ", \"" << name - << "\", reinterpret_cast(" << PrettyPrintAddr(&lm) - << "), reinterpret_cast*>(" - << PrettyPrintAddr(MakeWeakOnHeap(jittedDefine)) - << "), reinterpret_cast(" << definesAddr - << "), reinterpret_cast*>(" - << PrettyPrintAddr(upcastNodeOnHeap) << "));\n"; - - lm.ToJitExec(defineInvocation.str()); + std::stringstream defineInvocation; + defineInvocation << "(ROOT::Detail::RDF::RLoopManager *lm, " + << "std::shared_ptr *prevNodeOnHeap," + << "ROOT::Internal::RDF::RColumnRegister* colRegister, " + << "const std::vector & colNames, " + << "void *wkJittedDefine, void *) {\n"; + defineInvocation << "ROOT::Internal::RDF::JitDefineHelper(" << funcName + << ", colNames, \"" << name << "\", " + << "lm, " + << "reinterpret_cast*>(wkJittedDefine)," + << "colRegister, " + << "prevNodeOnHeap);\n}\n"; + + lm.RegisterJitHelperCall(defineInvocation.str(), upcastNodeOnHeap, definesCopy, parsedExpr.fUsedCols, + MakeWeakOnHeap(jittedDefine)); + return jittedDefine; } @@ -759,21 +752,21 @@ std::shared_ptr BookDefinePerSampleJit(std::string_view name, std auto definesAddr = PrettyPrintAddr(definesCopy); auto jittedDefine = std::make_shared(name, retType, lm, colRegister, ColumnNames_t{}); - std::stringstream defineInvocation; - defineInvocation << "ROOT::Internal::RDF::JitDefineHelper(" - << funcName << ", nullptr, 0, "; // lifetime of pointees: // - lm is the loop manager, and if that goes out of scope jitting does not happen at all (i.e. will always be valid) // - jittedDefine: heap-allocated weak_ptr that will be deleted by JitDefineHelper after usage // - definesAddr: heap-allocated, will be deleted by JitDefineHelper after usage - defineInvocation << "\"" << name << "\", reinterpret_cast(" << PrettyPrintAddr(&lm) - << "), reinterpret_cast*>(" - << PrettyPrintAddr(MakeWeakOnHeap(jittedDefine)) - << "), reinterpret_cast(" << definesAddr - << "), reinterpret_cast*>(" - << PrettyPrintAddr(upcastNodeOnHeap) << "));\n"; - - lm.ToJitExec(defineInvocation.str()); + std::stringstream defineInvocation; + defineInvocation << "(ROOT::Detail::RDF::RLoopManager *lm, " + << "std::shared_ptr *prevNodeOnHeap," + << "ROOT::Internal::RDF::RColumnRegister* colRegister, " + << "const std::vector & colNames, " + << "void *wkJittedDefine, void *) {\n"; + defineInvocation << "ROOT::Internal::RDF::JitDefineHelper(" + << funcName << ", colNames, \"" << name << "\", lm, " + << "reinterpret_cast*>(wkJittedDefine), " + << "colRegister, prevNodeOnHeap);\n}\n"; + lm.RegisterJitHelperCall(defineInvocation.str(), upcastNodeOnHeap, definesCopy, {}, MakeWeakOnHeap(jittedDefine)); return jittedDefine; } @@ -806,50 +799,43 @@ BookVariationJit(const std::vector &colNames, std::string_view vari const auto colRegisterAddr = PrettyPrintAddr(colRegisterCopy); auto jittedVariation = std::make_shared(colNames, variationName, variationTags, type, colRegister, lm, parsedExpr.fUsedCols); + auto variedColsOnHeap = new ColumnNames_t(colNames); // build invocation to JitVariationHelper - // arrays of strings are passed as const char** plus size. + // variation tag (array of strings) passed as const char** plus size. // lifetime of pointees: // - lm is the loop manager, and if that goes out of scope jitting does not happen at all (i.e. will always be valid) // - jittedVariation: heap-allocated weak_ptr that will be deleted by JitDefineHelper after usage // - definesAddr: heap-allocated, will be deleted by JitDefineHelper after usage + // - variedColsOnHeap: deleted by registration function std::stringstream varyInvocation; + varyInvocation << "(ROOT::Detail::RDF::RLoopManager *lm, " + << "std::shared_ptr *prevNodeOnHeap," + << "ROOT::Internal::RDF::RColumnRegister* colRegister, " + << "const std::vector & inputColNames, " + << "void *wkJittedVariation, void *variedColsOnHeap) {\n"; + varyInvocation << "auto * variedColNames = reinterpret_cast*>(variedColsOnHeap);\n"; varyInvocation << "ROOT::Internal::RDF::JitVariationHelper<" << (isSingleColumn ? "true" : "false") << ">(" - << funcName << ", new const char*[" << parsedExpr.fUsedCols.size() << "]{"; - for (const auto &col : parsedExpr.fUsedCols) { - varyInvocation << "\"" << col << "\", "; - } - if (!parsedExpr.fUsedCols.empty()) - varyInvocation.seekp(-2, varyInvocation.cur); // remove the last ", " - varyInvocation << "}, " << parsedExpr.fUsedCols.size(); - varyInvocation << ", new const char*[" << colNames.size() << "]{"; - for (const auto &col : colNames) { - varyInvocation << "\"" << col << "\", "; - } - varyInvocation.seekp(-2, varyInvocation.cur); // remove the last ", " - varyInvocation << "}, " << colNames.size() << ", new const char*[" << variationTags.size() << "]{"; + << funcName << ", inputColNames, *variedColNames, "; + varyInvocation << "new const char*[" << variationTags.size() << "]{"; for (const auto &tag : variationTags) { varyInvocation << "\"" << tag << "\", "; } varyInvocation.seekp(-2, varyInvocation.cur); // remove the last ", " - varyInvocation << "}, " << variationTags.size() << ", \"" << variationName - << "\", reinterpret_cast(" << PrettyPrintAddr(&lm) - << "), reinterpret_cast*>(" - << PrettyPrintAddr(MakeWeakOnHeap(jittedVariation)) - << "), reinterpret_cast(" << colRegisterAddr - << "), reinterpret_cast*>(" - << PrettyPrintAddr(upcastNodeOnHeap) << "));\n"; - - lm.ToJitExec(varyInvocation.str()); + varyInvocation << "}, " << variationTags.size() << ", \"" << variationName << "\", lm, " + << "reinterpret_cast*>(wkJittedVariation)," + << "colRegister, prevNodeOnHeap);\n" + << "delete variedColNames;\n}\n"; + lm.RegisterJitHelperCall(varyInvocation.str(), upcastNodeOnHeap, colRegisterCopy, parsedExpr.fUsedCols, MakeWeakOnHeap(jittedVariation), variedColsOnHeap); return jittedVariation; } // Jit and call something equivalent to "this->BuildAndBook(params...)" // (see comments in the body for actual jitted code) -std::string JitBuildAction(const ColumnNames_t &cols, std::shared_ptr *prevNode, - const std::type_info &helperArgType, const std::type_info &at, void *helperArgOnHeap, +std::string JitBuildAction(const ColumnNames_t &cols, + const std::type_info &helperArgType, const std::type_info &at, TTree *tree, const unsigned int nSlots, const RColumnRegister &colRegister, RDataSource *ds, - std::weak_ptr *jittedActionOnHeap, const bool vector2RVec) + const bool vector2RVec) { // retrieve type of action as a string auto actionTypeClass = TClass::GetClass(at); @@ -866,30 +852,22 @@ std::string JitBuildAction(const ColumnNames_t &cols, std::shared_ptr *prevNodeOnHeap," + << "ROOT::Internal::RDF::RColumnRegister* colRegister, " + << "const std::vector & colNames, " + << "void *wkJittedAction, void *actionArg) {\n"; createAction_str << "ROOT::Internal::RDF::CallBuildAction<" << actionTypeName; const auto columnTypeNames = GetValidatedArgTypes(cols, colRegister, tree, ds, actionTypeNameBase, vector2RVec); for (auto &colType : columnTypeNames) createAction_str << ", " << colType; - // on Windows, to prefix the hexadecimal value of a pointer with '0x', - // one need to write: std::hex << std::showbase << (size_t)pointer - createAction_str << ">(reinterpret_cast*>(" - << PrettyPrintAddr(prevNode) << "), new const char*[" << cols.size() << "]{"; - for (auto i = 0u; i < cols.size(); ++i) { - if (i != 0u) - createAction_str << ", "; - createAction_str << '"' << cols[i] << '"'; - } - createAction_str << "}, " << cols.size() << ", " << nSlots << ", reinterpret_cast*>(" << PrettyPrintAddr(helperArgOnHeap) - << "), reinterpret_cast*>(" - << PrettyPrintAddr(jittedActionOnHeap) - << "), reinterpret_cast(" << definesAddr << "));"; + createAction_str << ">(prevNodeOnHeap, colNames," << nSlots << ", " + << " reinterpret_cast*>(actionArg)," + << " reinterpret_cast*>(wkJittedAction)," + << "colRegister);\n}\n"; return createAction_str.str(); } diff --git a/tree/dataframe/src/RLoopManager.cxx b/tree/dataframe/src/RLoopManager.cxx index f0fae786055c6..0e4ea53c7791c 100644 --- a/tree/dataframe/src/RLoopManager.cxx +++ b/tree/dataframe/src/RLoopManager.cxx @@ -72,6 +72,19 @@ std::string &GetCodeToJit() return code; } +using JitHelperFunc = void(RLoopManager *, std::shared_ptr *, RColumnRegister *, const ColumnNames_t &, + void *, void *); +std::unordered_map &GetJitHelperFuncMap() +{ + static std::unordered_map map; + return map; +} +std::unordered_map &GetJitHelperNameMap() +{ + static std::unordered_map map; + return map; +} + bool ContainsLeaf(const std::set &leaves, TLeaf *leaf) { return (leaves.find(leaf) != leaves.end()); @@ -869,6 +882,26 @@ void RLoopManager::Jit() : " in less than 1ms."); } +void RLoopManager::RunDeferredCalls() +{ + if (!fJitHelperCalls.empty()) { + R__READ_LOCKGUARD(ROOT::gCoreMutex); // methods are thread-safe but funcMap isn't (yet) + TStopwatch s2; + s2.Start(); + auto &funcMap = GetJitHelperFuncMap(); + for (auto &call : fJitHelperCalls) { + assert(funcMap.find(call.functionId) != funcMap.end()); + funcMap[call.functionId](this, call.prevNodeOnHeap, call.colRegister, call.colNames, call.wkJittedNode, + call.argument); + } + s2.Stop(); + R__LOG_INFO(RDFLogChannel()) << "Deferred calls (" << fJitHelperCalls.size() << ") completed" + << (s2.RealTime() > 1e-3 ? " in " + std::to_string(s2.RealTime()) + " seconds." + : " in less than 1ms."); + fJitHelperCalls.clear(); + } +} + /// Trigger counting of number of children nodes for each node of the functional graph. /// This is done once before starting the event loop. Each action sends an `increase children count` signal /// upstream, which is propagated until RLoopManager. Each time a node receives the signal, in increments its @@ -891,13 +924,13 @@ void RLoopManager::Run(bool jit) // Change value of TTree::GetMaxTreeSize only for this scope. Revert when #6640 will be solved. MaxTreeSizeRAII ctxtmts; - R__LOG_INFO(RDFLogChannel()) << "Starting event loop number " << fNRuns << '.'; - ThrowIfNSlotsChanged(GetNSlots()); if (jit) Jit(); + RunDeferredCalls(); + InitNodes(); // Exceptions can occur during the event loop. In order to ensure proper cleanup of nodes @@ -1030,6 +1063,43 @@ void RLoopManager::ToJitExec(const std::string &code) const GetCodeToJit().append(code); } +void RLoopManager::RegisterJitHelperCall(const std::string &funcCode, std::shared_ptr *prevNodeOnHeap, + ROOT::Internal::RDF::RColumnRegister *colRegister, + const std::vector &colNames, void *wkJittedNode, void *argument) +{ + auto &nameMap = GetJitHelperNameMap(); + auto &funcMap = GetJitHelperFuncMap(); + { + R__READ_LOCKGUARD(ROOT::gCoreMutex); + auto match = nameMap.find(funcCode); + if (match != nameMap.end()) { + R__LOG_DEBUG(0, RDFLogChannel()) << "JitHelper " << match->second << " already defined"; + fJitHelperCalls.emplace_back(match->second, prevNodeOnHeap, colRegister, colNames, wkJittedNode, argument); + return; + } + } + + { + R__WRITE_LOCKGUARD(ROOT::gCoreMutex); + std::string registerId = "jitNodeRegistrator_" + std::to_string(nameMap.size()); + nameMap[funcCode] = registerId; + R__LOG_DEBUG(0, RDFLogChannel()) << "JitHelper new " << registerId << " defined for funcCode " << funcCode; + // step 1: register function (now) + std::string toDeclare = "namespace R_rdf {\n void " + registerId + funcCode + "\n}\n"; + ROOT::Internal::RDF::InterpreterDeclare(toDeclare); + std::stringstream registration; + registration + << "(*(reinterpret_cast " + "*,ROOT::Internal::RDF::RColumnRegister *, const std::vector & colNames, void*,void*)>*>("; + registration << PrettyPrintAddr((void *)(&funcMap)); + registration << ")))[\"" << registerId << "\"] = R_rdf::" << registerId << ";\n"; + std::string registrationStr = registration.str(); + GetCodeToJit().append(registrationStr); + fJitHelperCalls.emplace_back(registerId, prevNodeOnHeap, colRegister, colNames, wkJittedNode, argument); + } +} + void RLoopManager::RegisterCallback(ULong64_t everyNEvents, std::function &&f) { if (everyNEvents == 0ull) From 3867133479b24ca7533b57bbfe596f1b02eb70c1 Mon Sep 17 00:00:00 2001 From: Giovanni Date: Sun, 8 Dec 2024 23:31:21 +0100 Subject: [PATCH 2/3] [df] Gather all declarations and JIT them at once to reduce overhead --- tree/dataframe/src/RLoopManager.cxx | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tree/dataframe/src/RLoopManager.cxx b/tree/dataframe/src/RLoopManager.cxx index 0e4ea53c7791c..9aa322b89bdc5 100644 --- a/tree/dataframe/src/RLoopManager.cxx +++ b/tree/dataframe/src/RLoopManager.cxx @@ -71,6 +71,11 @@ std::string &GetCodeToJit() static std::string code; return code; } +std::string &GetCodeToDeclare() +{ + static std::string code; + return code; +} using JitHelperFunc = void(RLoopManager *, std::shared_ptr *, RColumnRegister *, const ColumnNames_t &, void *, void *); @@ -862,19 +867,22 @@ void RLoopManager::Jit() { { R__READ_LOCKGUARD(ROOT::gCoreMutex); - if (GetCodeToJit().empty()) { + if (GetCodeToJit().empty() && GetCodeToDeclare().empty()) { R__LOG_INFO(RDFLogChannel()) << "Nothing to jit and execute."; return; } } - const std::string code = []() { + std::string codeToDeclare, code; + { R__WRITE_LOCKGUARD(ROOT::gCoreMutex); - return std::move(GetCodeToJit()); - }(); + codeToDeclare.swap(GetCodeToDeclare()); + code.swap(GetCodeToJit()); + }; TStopwatch s; s.Start(); + ROOT::Internal::RDF::InterpreterDeclare(codeToDeclare); RDFInternal::InterpreterCalc(code, "RLoopManager::Run"); s.Stop(); R__LOG_INFO(RDFLogChannel()) << "Just-in-time compilation phase completed" @@ -1086,7 +1094,7 @@ void RLoopManager::RegisterJitHelperCall(const std::string &funcCode, std::share R__LOG_DEBUG(0, RDFLogChannel()) << "JitHelper new " << registerId << " defined for funcCode " << funcCode; // step 1: register function (now) std::string toDeclare = "namespace R_rdf {\n void " + registerId + funcCode + "\n}\n"; - ROOT::Internal::RDF::InterpreterDeclare(toDeclare); + GetCodeToDeclare().append(toDeclare); std::stringstream registration; registration << "(*(reinterpret_cast Date: Tue, 10 Dec 2024 11:16:49 +0100 Subject: [PATCH 3/3] [df] Speed up function pointer lookup --- tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx | 1 - tree/dataframe/src/RLoopManager.cxx | 41 ++++++++++++++------ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx b/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx index 314ec5fb55410..feee936c94903 100644 --- a/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx @@ -229,7 +229,6 @@ public: RLoopManager &operator=(RLoopManager &&) = delete; ~RLoopManager() = default; - void JitDeclarations(); void Jit(); void RunDeferredCalls(); RLoopManager *GetLoopManagerUnchecked() final { return this; } diff --git a/tree/dataframe/src/RLoopManager.cxx b/tree/dataframe/src/RLoopManager.cxx index 9aa322b89bdc5..9c6707c0f02e6 100644 --- a/tree/dataframe/src/RLoopManager.cxx +++ b/tree/dataframe/src/RLoopManager.cxx @@ -27,6 +27,7 @@ #include "TEntryList.h" #include "TFile.h" #include "TFriendElement.h" +#include "TInterpreter.h" #include "TROOT.h" // IsImplicitMTEnabled, gCoreMutex, R__*_LOCKGUARD #include "TTreeReader.h" #include "TTree.h" // For MaxTreeSizeRAII. Revert when #6640 will be solved. @@ -868,6 +869,7 @@ void RLoopManager::Jit() { R__READ_LOCKGUARD(ROOT::gCoreMutex); if (GetCodeToJit().empty() && GetCodeToDeclare().empty()) { + RunDeferredCalls(); R__LOG_INFO(RDFLogChannel()) << "Nothing to jit and execute."; return; } @@ -882,12 +884,37 @@ void RLoopManager::Jit() TStopwatch s; s.Start(); - ROOT::Internal::RDF::InterpreterDeclare(codeToDeclare); - RDFInternal::InterpreterCalc(code, "RLoopManager::Run"); + if (!codeToDeclare.empty()) { + ROOT::Internal::RDF::InterpreterDeclare(codeToDeclare); + auto &funcMap = GetJitHelperFuncMap(); + auto &nameMap = GetJitHelperNameMap(); + auto clinfo = gInterpreter->ClassInfo_Factory("R_rdf"); + assert(gInterpreter->ClassInfo_IsValid(clinfo)); + for (auto & codeAndName : nameMap) { + JitHelperFunc * & addr = funcMap[codeAndName.second]; + if (!addr) { + // fast fetch of the address via gInterpreter + // (faster than gInterpreter->Evaluate(function name, ret), ret->GetAsPointer()) + auto declid = gInterpreter->GetFunction(clinfo, codeAndName.second.c_str()); + assert(declid); + auto minfo = gInterpreter->MethodInfo_Factory(declid); + assert(gInterpreter->MethodInfo_IsValid(minfo)); + auto mname = gInterpreter->MethodInfo_GetMangledName(minfo); + addr = reinterpret_cast(gInterpreter->FindSym(mname)); + gInterpreter->MethodInfo_Delete(minfo); + } + } + gInterpreter->ClassInfo_Delete(clinfo); + } + if (!code.empty()) { + RDFInternal::InterpreterCalc(code, "RLoopManager::Run"); + } s.Stop(); R__LOG_INFO(RDFLogChannel()) << "Just-in-time compilation phase completed" << (s.RealTime() > 1e-3 ? " in " + std::to_string(s.RealTime()) + " seconds." : " in less than 1ms."); + + RunDeferredCalls(); } void RLoopManager::RunDeferredCalls() @@ -1076,7 +1103,6 @@ void RLoopManager::RegisterJitHelperCall(const std::string &funcCode, std::share const std::vector &colNames, void *wkJittedNode, void *argument) { auto &nameMap = GetJitHelperNameMap(); - auto &funcMap = GetJitHelperFuncMap(); { R__READ_LOCKGUARD(ROOT::gCoreMutex); auto match = nameMap.find(funcCode); @@ -1095,15 +1121,6 @@ void RLoopManager::RegisterJitHelperCall(const std::string &funcCode, std::share // step 1: register function (now) std::string toDeclare = "namespace R_rdf {\n void " + registerId + funcCode + "\n}\n"; GetCodeToDeclare().append(toDeclare); - std::stringstream registration; - registration - << "(*(reinterpret_cast " - "*,ROOT::Internal::RDF::RColumnRegister *, const std::vector & colNames, void*,void*)>*>("; - registration << PrettyPrintAddr((void *)(&funcMap)); - registration << ")))[\"" << registerId << "\"] = R_rdf::" << registerId << ";\n"; - std::string registrationStr = registration.str(); - GetCodeToJit().append(registrationStr); fJitHelperCalls.emplace_back(registerId, prevNodeOnHeap, colRegister, colNames, wkJittedNode, argument); } }