Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Graph] Refactor node storage inside graphs #334

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 64 additions & 87 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,50 +29,6 @@ namespace experimental {
namespace detail {

namespace {

/// Recursively check if a given node is an exit node, and add the new nodes as
/// successors if so.
/// @param[in] CurrentNode Node to check as exit node.
/// @param[in] NewInputs Noes to add as successors.
void connectToExitNodes(
std::shared_ptr<node_impl> CurrentNode,
const std::vector<std::shared_ptr<node_impl>> &NewInputs) {
if (CurrentNode->MSuccessors.size() > 0) {
for (auto &Successor : CurrentNode->MSuccessors) {
connectToExitNodes(Successor, NewInputs);
}

} else {
for (auto &Input : NewInputs) {
CurrentNode->registerSuccessor(Input, CurrentNode);
}
}
}

/// Recursive check if a graph node or its successors contains a given
/// requirement.
/// @param[in] Req The requirement to check for.
/// @param[in] CurrentNode The current graph node being checked.
/// @param[in,out] Deps The unique list of dependencies which have been
/// identified for this requirement.
/// @return True if a dependency was added in this node or any of its
/// successors.
bool checkForRequirement(sycl::detail::AccessorImplHost *Req,
const std::shared_ptr<node_impl> &CurrentNode,
std::set<std::shared_ptr<node_impl>> &Deps) {
bool SuccessorAddedDep = false;
for (auto &Successor : CurrentNode->MSuccessors) {
SuccessorAddedDep |= checkForRequirement(Req, Successor, Deps);
}

if (!CurrentNode->isEmpty() && Deps.find(CurrentNode) == Deps.end() &&
CurrentNode->hasRequirement(Req) && !SuccessorAddedDep) {
Deps.insert(CurrentNode);
return true;
}
return SuccessorAddedDep;
}

/// Visits a node on the graph and it's successors recursively in a depth-first
/// approach.
/// @param[in] Node The current node being visited.
Expand All @@ -99,7 +55,8 @@ bool visitNodeDepthFirst(
Node->MVisited = true;
VisitedNodes.emplace(Node);
for (auto &Successor : Node->MSuccessors) {
if (visitNodeDepthFirst(Successor, VisitedNodes, NodeStack, NodeFunc)) {
if (visitNodeDepthFirst(Successor.lock(), VisitedNodes, NodeStack,
NodeFunc)) {
return true;
}
}
Expand All @@ -117,12 +74,28 @@ void duplicateNode(const std::shared_ptr<node_impl> Node,
}
}

/// Recursively add nodes to execution stack.
/// @param NodeImpl Node to schedule.
/// @param Schedule Execution ordering to add node to.
void sortTopological(std::shared_ptr<node_impl> NodeImpl,
std::list<std::shared_ptr<node_impl>> &Schedule) {
for (auto &Succ : NodeImpl->MSuccessors) {
// Check if we've already scheduled this node
auto NextNode = Succ.lock();
if (std::find(Schedule.begin(), Schedule.end(), NextNode) ==
Schedule.end()) {
sortTopological(NextNode, Schedule);
}
}

Schedule.push_front(NodeImpl);
}
} // anonymous namespace

void exec_graph_impl::schedule() {
if (MSchedule.empty()) {
for (auto &Node : MGraphImpl->MRoots) {
Node->sortTopological(Node, MSchedule);
sortTopological(Node.lock(), MSchedule);
}
}
}
Expand All @@ -148,10 +121,19 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
}
}

// Recursively walk the graph to find exit nodes and connect up the inputs
// TODO: Consider caching exit nodes so we don't have to do this
for (auto &NodeImpl : MRoots) {
connectToExitNodes(NodeImpl, Inputs);
// Find all exit nodes in the current graph and register the Inputs as
// successors
for (auto &NodeImpl : MNodeStorage) {
if (NodeImpl->MSuccessors.size() == 0) {
for (auto &Input : Inputs) {
NodeImpl->registerSuccessor(Input, NodeImpl);
}
}
}

// Add all the new nodes to the node storage
for (auto &Node : NodeList) {
MNodeStorage.push_back(Node);
}

return this->add(Outputs);
Expand All @@ -175,7 +157,7 @@ std::shared_ptr<node_impl> graph_impl::addSubgraphNodes(
*NewNodesIt = NodeCopy;
NodesMap.insert({Node, NodeCopy});
for (auto &NextNode : Node->MSuccessors) {
auto Successor = NodesMap.at(NextNode);
auto Successor = NodesMap.at(NextNode.lock());
NodeCopy->registerSuccessor(Successor, NodeCopy);
}
}
Expand All @@ -201,16 +183,9 @@ graph_impl::add(const std::vector<std::shared_ptr<node_impl>> &Dep) {
// Add any deps from the vector of extra dependencies
Deps.insert(Deps.end(), MExtraDependencies.begin(), MExtraDependencies.end());

// TODO: Encapsulate in separate function to avoid duplication
if (!Deps.empty()) {
for (auto &N : Deps) {
N->registerSuccessor(NodeImpl, N); // register successor
this->removeRoot(NodeImpl); // remove receiver from root node
// list
}
} else {
this->addRoot(NodeImpl);
}
MNodeStorage.push_back(NodeImpl);

addDepsToNode(NodeImpl, Deps);

return NodeImpl;
}
Expand Down Expand Up @@ -285,17 +260,28 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
MemObj->markBeingUsedInGraph();
}
// Look through the graph for nodes which share this requirement
for (auto &NodePtr : MRoots) {
checkForRequirement(Req, NodePtr, UniqueDeps);
for (auto &Node : MNodeStorage) {
if (Node->hasRequirement(Req)) {
bool ShouldAddDep = true;
// If any of this node's successors have this requirement then we skip
// adding the current node as a dependency.
for (auto &Succ : Node->MSuccessors) {
if (Succ.lock()->hasRequirement(Req)) {
ShouldAddDep = false;
break;
}
}
if (ShouldAddDep) {
UniqueDeps.insert(Node);
}
}
}
}

// Add any nodes specified by event dependencies into the dependency list
for (auto &Dep : CommandGroup->getEvents()) {
if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl != MEventsMap.end()) {
if (UniqueDeps.find(NodeImpl->second) == UniqueDeps.end()) {
UniqueDeps.insert(NodeImpl->second);
}
UniqueDeps.insert(NodeImpl->second);
} else {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"Event dependency from handler::depends_on does "
Expand All @@ -311,15 +297,9 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,

const std::shared_ptr<node_impl> &NodeImpl =
std::make_shared<node_impl>(CGType, std::move(CommandGroup));
if (!Deps.empty()) {
for (auto &N : Deps) {
N->registerSuccessor(NodeImpl, N); // register successor
this->removeRoot(NodeImpl); // remove receiver from root node
// list
}
} else {
this->addRoot(NodeImpl);
}
MNodeStorage.push_back(NodeImpl);

addDepsToNode(NodeImpl, Deps);

// Set barrier nodes as prerequisites (new start points) for subsequent nodes
if (CGType == sycl::detail::CG::Barrier) {
Expand Down Expand Up @@ -353,7 +333,7 @@ void graph_impl::searchDepthFirst(

for (auto &Root : MRoots) {
std::deque<std::shared_ptr<node_impl>> NodeStack;
if (visitNodeDepthFirst(Root, VisitedNodes, NodeStack, NodeFunc)) {
if (visitNodeDepthFirst(Root.lock(), VisitedNodes, NodeStack, NodeFunc)) {
break;
}
}
Expand Down Expand Up @@ -394,18 +374,15 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,

bool SrcFound = false;
bool DestFound = false;
auto CheckForNodes = [&](std::shared_ptr<node_impl> &Node,
std::deque<std::shared_ptr<node_impl>> &) {
if (Node == Src) {
SrcFound = true;
}
if (Node == Dest) {
DestFound = true;
}
return SrcFound && DestFound;
};
for (const auto &Node : MNodeStorage) {

SrcFound |= Node == Src;
DestFound |= Node == Dest;

searchDepthFirst(CheckForNodes);
if (SrcFound && DestFound) {
break;
}
}

if (!SrcFound) {
throw sycl::exception(make_error_code(sycl::errc::invalid),
Expand Down
Loading
Loading