Skip to content

[SYCL][Graph] Breadth-first schedule #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: sycl-graph-develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,6 @@ bool check_for_arg(const sycl::detail::ArgDesc &Arg,
}
} // anonymous namespace

void exec_graph_impl::schedule() {
if (MSchedule.empty()) {
for (auto Node : MGraphImpl->MRoots) {
Node->topology_sort(Node, MSchedule);
}
}
}

std::shared_ptr<node_impl> graph_impl::add_subgraph_nodes(
const std::list<std::shared_ptr<node_impl>> &NodeList) {
// Find all input and output nodes from the node list
Expand Down Expand Up @@ -564,7 +556,6 @@ command_graph<graph_state::executable>::command_graph(

void command_graph<graph_state::executable>::finalize_impl() {
// Create PI command-buffers for each device in the finalized context
impl->schedule();

auto Context = impl->get_context();
for (auto Device : impl->get_context().get_devices()) {
Expand Down
104 changes: 82 additions & 22 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <functional>
#include <list>
#include <set>
#include <optional>
#include <map>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -73,22 +75,28 @@ class node_impl {
std::unique_ptr<sycl::detail::CG> &&CommandGroup)
: MCGType(CGType), MCommandGroup(std::move(CommandGroup)) {}

/// Recursively add nodes to execution stack.
/// @param NodeImpl Node to schedule.
/// @param Schedule Execution ordering to add node to.
void topology_sort(std::shared_ptr<node_impl> NodeImpl,
std::list<std::shared_ptr<node_impl>> &Schedule) {
for (auto Next : MSuccessors) {
// Check if we've already scheduled this node
if (std::find(Schedule.begin(), Schedule.end(), Next) == Schedule.end())
Next->topology_sort(Next, Schedule);
}
// We don't need to schedule empty nodes as they are only used when
// calculating dependencies
if (!NodeImpl->is_empty())
Schedule.push_front(NodeImpl);
}
private:
/// Depth of this node in a containing graph
///
/// The first call to graph.exec_order_recompute computes & caches the value
/// It will likely become stale whenever the containing graph is changed and
/// a single value will be inequate if this node is added to multiple graphs
/// Caching is dangerous but recomputing takes O(graph_size) worst-case time
std::optional<int> MDepth;

public:
int get_depth(node_impl &V) { return V.get_depth(); };
int get_depth() {
if (!MDepth.has_value()) {
int max_depth_found = -1;
for (auto P : MPredecessors) {
max_depth_found = std::max(max_depth_found, P.lock()->get_depth());
}
MDepth = max_depth_found + 1;
}
return MDepth.value();
};

/// Checks if this node has an argument.
/// @param Arg Argument to lookup.
/// @return True if \p Arg is used in node, false otherwise.
Expand Down Expand Up @@ -197,6 +205,63 @@ class graph_impl {
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
MEventsMap() {}

private:
/// A cache of pointers to exit nodes
///
/// This is not used (yet), but depth computation starts from exit nodes
/// Perhaps, it might be better to do the exec_order_recompute traversal
/// starting from each exit node and working upwards using MPredecessors
/// rather than from each root node and doing depth-first to exit nodes?
std::vector<node_impl *> MExitNodes;

/// A sorted multimap capturing the optimal execution/submission order
///
/// The SortKey is the depth in the graph for the node_impl in the value
/// Depth is the length of the longest dependence chain to any root node
std::multimap<int, std::shared_ptr<node_impl>> MExecOrder;

/// <summary>
/// Depth-first recursion from V to build the optimal execution order
/// </summary>
/// <param name="V">Starting node for depth-first recursion</param>
void exec_order_recompute(node_impl &V) {
// depth-first recursion to access all nodes that succeed this node
for (auto &S : V.MSuccessors) {
exec_order_recompute(*S.get());
}
// insert this into execution order based on its depth in the graph
MExecOrder.insert(std::pair(V.get_depth(), &V));
// cache all the exit nodes; no reason, just feels like a good idea
if (V.MSuccessors.empty()) {
MExitNodes.push_back(&V);
}
};

/// <summary>
/// Recomputes the optimal submission/execution order for this whole graph
/// </summary>
void exec_order_recompute() {
MExecOrder.clear();
// for all root nodes ...
for (auto &root : MRoots) {
// ... recurse towards all exit nodes
exec_order_recompute(*root);
}
};

public:
/// <summary>
/// Recomputes the optimal submission/execution order then schedules all nodes
/// </summary>
std::list<std::shared_ptr<node_impl>> compute_schedule() {
exec_order_recompute();
std::list<std::shared_ptr<node_impl>> sched;
for (auto &next : MExecOrder) {
sched.push_front(*next.second.get());
}
return sched;
};

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void add_root(const std::shared_ptr<node_impl> &Root);
Expand Down Expand Up @@ -314,17 +379,15 @@ class exec_graph_impl {
/// @param GraphImpl Modifiable graph implementation to create with.
exec_graph_impl(sycl::context Context,
const std::shared_ptr<graph_impl> &GraphImpl)
: MSchedule(), MGraphImpl(GraphImpl), MPiCommandBuffers(),
: MSchedule(GraphImpl->compute_schedule()),
MPiCommandBuffers(),
MPiSyncPoints(), MContext(Context) {}

/// Destructor.
///
/// Releases any PI command-buffers the object has created.
~exec_graph_impl();

/// Add nodes to MSchedule.
void schedule();

/// Enqueues the backend objects for the graph to the parametrized queue.
/// @param Queue Command-queue to submit backend objects to.
/// @return Event associated with enqueued object.
Expand Down Expand Up @@ -384,9 +447,6 @@ class exec_graph_impl {

/// Execution schedule of nodes in the graph.
std::list<std::shared_ptr<node_impl>> MSchedule;
/// Pointer to the modifiable graph impl associated with this executable
/// graph.
std::shared_ptr<graph_impl> MGraphImpl;
/// Map of devices to command buffers.
std::unordered_map<sycl::device, RT::PiExtCommandBuffer> MPiCommandBuffers;
/// Map of nodes in the exec graph to the sync point representing their
Expand Down