diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index 42720adb9194a..48ae8756f82cf 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -73,14 +73,6 @@ bool check_for_requirement(sycl::detail::AccessorImplHost *Req, } } // anonymous namespace -void exec_graph_impl::schedule() { - if (MSchedule.empty()) { - for (auto Node : MGraphImpl->MRoots) { - Node->topology_sort(Node, MSchedule); - } - } -} - std::shared_ptr graph_impl::add_subgraph_nodes( const std::list> &NodeList) { // Find all input and output nodes from the node list @@ -104,6 +96,15 @@ std::shared_ptr graph_impl::add_subgraph_nodes( return this->add(Outputs); } +std::list> graph_impl::compute_schedule() { + exec_order_recompute(); + std::list> Sched; + for (auto &Next : MExecOrder) { + Sched.push_back(Next.second); + } + return Sched; +}; + void graph_impl::add_root(const std::shared_ptr &Root) { MRoots.insert(Root); } @@ -571,7 +572,6 @@ command_graph::command_graph( void command_graph::finalize_impl() { // Create PI command-buffers for each device in the finalized context - impl->schedule(); auto Context = impl->get_context(); for (auto Device : impl->get_context().get_devices()) { diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index 8b23d254e359d..1a5cc3a54589b 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -20,6 +20,8 @@ #include #include #include +#include +#include namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { @@ -75,22 +77,29 @@ class node_impl { std::unique_ptr &&CommandGroup) : MCGType(CGType), MCommandGroup(std::move(CommandGroup)) {} - /// Recursively add nodes to execution stack. - /// @param NodeImpl Node to schedule. - /// @param Schedule Execution ordering to add node to. - void topology_sort(std::shared_ptr NodeImpl, - std::list> &Schedule) { - for (auto Next : MSuccessors) { - // Check if we've already scheduled this node - if (std::find(Schedule.begin(), Schedule.end(), Next) == Schedule.end()) - Next->topology_sort(Next, Schedule); - } - // We don't need to schedule empty nodes as they are only used when - // calculating dependencies - if (!NodeImpl->is_empty()) - Schedule.push_front(NodeImpl); - } +private: + /// Depth of this node in a containing graph. + /// + /// The first call to graph.exec_order_recompute computes & caches the value. + /// It will likely become stale whenever the containing graph is changed and + /// a single value will be inadequate if this node is added to multiple graphs. + /// Caching is dangerous but recomputing takes O(graph_size) worst-case time. + std::optional MDepth; +public: + /// Gets the depth of this node in its containing graph. + /// @return the depth of this node in its containing graph. + int get_depth() { + if (!MDepth.has_value()) { + int MaxDepthFound = -1; + for (auto &P : MPredecessors) { + MaxDepthFound = std::max(MaxDepthFound, P.lock()->get_depth()); + } + MDepth = MaxDepthFound + 1; + } + return MDepth.value(); + }; + /// Checks if this node has a given requirement. /// @param Requirement Requirement to lookup. /// @return True if \p Requirement is present in node, false otherwise. @@ -180,7 +189,7 @@ class node_impl { } }; -/// Class representing implementation details of command_graph. +/// Class representing implementation details of modifiable command_graph. class graph_impl { public: /// Constructor. @@ -190,6 +199,39 @@ class graph_impl { : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(), MEventsMap() {} +private: + /// A sorted multimap capturing a breadth-first execution/submission order. + /// + /// The SortKey is the depth in the graph for the node_impl in the value. + /// Depth is the length of the longest dependence chain to any root node. + std::multimap> MExecOrder; + + /// Depth-first recursion from V to build the execution order. + /// @param V Starting node for depth-first recursion. + void exec_order_recompute(node_impl &V) { + // depth-first recursion to access all nodes that succeed this node + for (auto &S : V.MSuccessors) { + exec_order_recompute(*S.get()); + } + // insert this into execution order based on its depth in the graph + MExecOrder.insert(std::pair(V.get_depth(), &V)); + }; + + /// Recomputes the submission/execution order for this whole graph. + void exec_order_recompute() { + MExecOrder.clear(); + // for all root nodes ... + for (auto &Root : MRoots) { + // ... recurse towards all exit nodes + exec_order_recompute(*Root); + } + }; + +public: + /// Recomputes the submission/execution order then schedules all nodes. + /// @return A list of shared pointers to nodes in linear scheduling order. + std::list> compute_schedule(); + /// Insert node into list of root nodes. /// @param Root Node to add to list of root nodes. void add_root(const std::shared_ptr &Root); @@ -313,7 +355,8 @@ class exec_graph_impl { /// @param GraphImpl Modifiable graph implementation to create with. exec_graph_impl(sycl::context Context, const std::shared_ptr &GraphImpl) - : MSchedule(), MGraphImpl(GraphImpl), MPiCommandBuffers(), + : MSchedule(GraphImpl->compute_schedule()), + MPiCommandBuffers(), MPiSyncPoints(), MContext(Context) {} /// Destructor. @@ -321,9 +364,6 @@ class exec_graph_impl { /// Releases any PI command-buffers the object has created. ~exec_graph_impl(); - /// Add nodes to MSchedule. - void schedule(); - /// Called by handler::ext_oneapi_command_graph() to schedule graph for /// execution. /// @param Queue Command-queue to schedule execution on. @@ -378,9 +418,6 @@ class exec_graph_impl { /// Execution schedule of nodes in the graph. std::list> MSchedule; - /// Pointer to the modifiable graph impl associated with this executable - /// graph. - std::shared_ptr MGraphImpl; /// Map of devices to command buffers. std::unordered_map MPiCommandBuffers; /// Map of nodes in the exec graph to the sync point representing their