Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[df] [for discussion] JIT graph creation functions only once #17282

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 9 additions & 25 deletions tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,9 @@ BookVariationJit(const std::vector<std::string> &colNames, std::string_view vari
RDataSource *ds, const RColumnRegister &colRegister, const ColumnNames_t &branches,
std::shared_ptr<RNodeBase> *upcastNodeOnHeap, bool isSingleColumn);

std::string JitBuildAction(const ColumnNames_t &bl, std::shared_ptr<RDFDetail::RNodeBase> *prevNode,
const std::type_info &art, const std::type_info &at, void *rOnHeap, TTree *tree,
std::string JitBuildAction(const ColumnNames_t &bl, const std::type_info &art, const std::type_info &at, TTree *tree,
const unsigned int nSlots, const RColumnRegister &colRegister, RDataSource *ds,
std::weak_ptr<RJittedAction> *jittedActionOnHeap, const bool vector2RVec = true);
const bool vector2RVec = true);

// Allocate a weak_ptr on the heap, return a pointer to it. The user is responsible for deleting this weak_ptr.
// This function is meant to be used by RInterface's methods that book code for jitting.
Expand Down Expand Up @@ -420,7 +419,7 @@ void AddDSColumns(const std::vector<std::string> &requiredCols, RLoopManager &lm

// this function is meant to be called by the jitted code generated by BookFilterJit
template <typename F, typename PrevNode>
void JitFilterHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::string_view name,
void JitFilterHelper(F &&f, const ColumnNames_t &cols, std::string_view name,
std::weak_ptr<RJittedFilter> *wkJittedFilter, std::shared_ptr<PrevNode> *prevNodeOnHeap,
RColumnRegister *colRegister) noexcept
{
Expand All @@ -433,9 +432,6 @@ void JitFilterHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);
delete[] colsPtr;

const auto jittedFilter = wkJittedFilter->lock();

// mock Filter logic -- validity checks and Define-ition of RDataSource columns
Expand Down Expand Up @@ -485,7 +481,7 @@ auto MakeDefineNode(DefineTypes::RDefinePerSampleTag, std::string_view name, std
// This function is meant to be called by jitted code right before starting the event loop.
// If colsPtr is null, build a RDefinePerSample (it has no input columns), otherwise a RDefine.
template <typename RDefineTypeTag, typename F>
void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::string_view name, RLoopManager *lm,
void JitDefineHelper(F &&f, const ColumnNames_t &cols, std::string_view name, RLoopManager *lm,
std::weak_ptr<RJittedDefine> *wkJittedDefine, RColumnRegister *colRegister,
std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
{
Expand All @@ -494,7 +490,6 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str
delete wkJittedDefine;
delete colRegister;
delete prevNodeOnHeap;
delete[] colsPtr;
};

if (wkJittedDefine->expired()) {
Expand All @@ -504,8 +499,6 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);

auto jittedDefine = wkJittedDefine->lock();

using Callable_t = std::decay_t<F>;
Expand All @@ -527,18 +520,14 @@ void JitDefineHelper(F &&f, const char **colsPtr, std::size_t colsSize, std::str
}

template <bool IsSingleColumn, typename F>
void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const char **variedCols,
std::size_t variedColsSize, const char **variationTags, std::size_t variationTagsSize,
std::string_view variationName, RLoopManager *lm,
std::weak_ptr<RJittedVariation> *wkJittedVariation, RColumnRegister *colRegister,
std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
void JitVariationHelper(F &&f, const ColumnNames_t &inputColNames, const ColumnNames_t &variedColNames,
const char **variationTags, std::size_t variationTagsSize, std::string_view variationName,
RLoopManager *lm, std::weak_ptr<RJittedVariation> *wkJittedVariation,
RColumnRegister *colRegister, std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
{
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
auto doDeletes = [&] {
delete[] colsPtr;
delete[] variedCols;
delete[] variationTags;

delete wkJittedVariation;
delete colRegister;
delete prevNodeOnHeap;
Expand All @@ -551,8 +540,6 @@ void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const
return;
}

const ColumnNames_t inputColNames(colsPtr, colsPtr + colsSize);
std::vector<std::string> variedColNames(variedCols, variedCols + variedColsSize);
std::vector<std::string> tags(variationTags, variationTags + variationTagsSize);

auto jittedVariation = wkJittedVariation->lock();
Expand All @@ -575,13 +562,12 @@ void JitVariationHelper(F &&f, const char **colsPtr, std::size_t colsSize, const

/// Convenience function invoked by jitted code to build action nodes at runtime
template <typename ActionTag, typename... ColTypes, typename PrevNodeType, typename HelperArgType>
void CallBuildAction(std::shared_ptr<PrevNodeType> *prevNodeOnHeap, const char **colsPtr, std::size_t colsSize,
void CallBuildAction(std::shared_ptr<PrevNodeType> *prevNodeOnHeap, const ColumnNames_t &cols,
const unsigned int nSlots, std::shared_ptr<HelperArgType> *helperArgOnHeap,
std::weak_ptr<RJittedAction> *wkJittedActionOnHeap, RColumnRegister *colRegister) noexcept
{
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
auto doDeletes = [&] {
delete[] colsPtr;
delete helperArgOnHeap;
delete wkJittedActionOnHeap;
// colRegister must be deleted before prevNodeOnHeap because their dtor needs the RLoopManager to be alive
Expand All @@ -597,8 +583,6 @@ void CallBuildAction(std::shared_ptr<PrevNodeType> *prevNodeOnHeap, const char *
return;
}

const ColumnNames_t cols(colsPtr, colsPtr + colsSize);

auto jittedActionOnHeap = wkJittedActionOnHeap->lock();

// if we are here it means we are jitting, if we are jitting the loop manager must be alive
Expand Down
9 changes: 5 additions & 4 deletions tree/dataframe/inc/ROOT/RDF/RInterfaceBase.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,11 @@ protected:
fColRegister, proxiedPtr->GetVariations());
auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction);

auto toJit = RDFInternal::JitBuildAction(validColumnNames, upcastNodeOnHeap, typeid(HelperArgType),
typeid(ActionTag), helperArgOnHeap, tree, nSlots, fColRegister,
fDataSource, jittedActionOnHeap, vector2RVec);
fLoopManager->ToJitExec(toJit);
auto definesCopy = new RDFInternal::RColumnRegister(fColRegister); // deleted in jitted call
auto funcBody = RDFInternal::JitBuildAction(validColumnNames, typeid(HelperArgType), typeid(ActionTag), tree,
nSlots, fColRegister, fDataSource, vector2RVec);
fLoopManager->RegisterJitHelperCall(funcBody, upcastNodeOnHeap, definesCopy, validColumnNames, jittedActionOnHeap,
helperArgOnHeap);
return MakeResultPtr(r, *fLoopManager, std::move(jittedAction));
}

Expand Down
27 changes: 26 additions & 1 deletion tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class RActionBase;
class RVariationBase;
class RDefinesWithReaders;
class RVariationsWithReaders;
class RColumnRegister;

namespace GraphDrawing {
class GraphCreatorHelper;
Expand Down Expand Up @@ -192,6 +193,27 @@ class RLoopManager : public RNodeBase {
std::set<std::pair<std::string_view, std::unique_ptr<ROOT::Internal::RDF::RVariationsWithReaders>>>
fUniqueVariationsWithReaders;

// deferred function calls to Jitted functions
struct DeferredJitCall {
std::string functionId;
std::shared_ptr<RNodeBase> *prevNodeOnHeap;
ROOT::Internal::RDF::RColumnRegister *colRegister;
std::vector<std::string> colNames;
void *wkJittedNode, *argument;
DeferredJitCall(const std::string &id, std::shared_ptr<RNodeBase> *prevNode,
ROOT::Internal::RDF::RColumnRegister *cols, const std::vector<std::string> &colnames,
void *wkNodePtr, void *arg)
: functionId(id),
prevNodeOnHeap(prevNode),
colRegister(cols),
colNames(colnames),
wkJittedNode(wkNodePtr),
argument(arg)
{
}
};
std::vector<DeferredJitCall> fJitHelperCalls;

public:
RLoopManager(TTree *tree, const ColumnNames_t &defaultBranches);
RLoopManager(std::unique_ptr<TTree> tree, const ColumnNames_t &defaultBranches);
Expand All @@ -207,8 +229,8 @@ public:
RLoopManager &operator=(RLoopManager &&) = delete;
~RLoopManager() = default;

void JitDeclarations();
void Jit();
void RunDeferredCalls();
RLoopManager *GetLoopManagerUnchecked() final { return this; }
void Run(bool jit = true);
const ColumnNames_t &GetDefaultColumnNames() const;
Expand All @@ -235,6 +257,9 @@ public:
void IncrChildrenCount() final { ++fNChildren; }
void StopProcessing() final { ++fNStopsReceived; }
void ToJitExec(const std::string &) const;
void RegisterJitHelperCall(const std::string &funcBody, std::shared_ptr<RNodeBase> *prevNodeOnHeap,
ROOT::Internal::RDF::RColumnRegister *colRegister,
const std::vector<std::string> &colNames, void *wkJittedPtr, void *argument = nullptr);
void RegisterCallback(ULong64_t everyNEvents, std::function<void(unsigned int)> &&f);
unsigned int GetNRuns() const { return fNRuns; }
bool HasDataSourceColumnReaders(const std::string &col, const std::type_info &ti) const;
Expand Down
Loading
Loading