diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index a6eb5d7eaba6f..5cae341e62073 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -15,6 +15,8 @@ #include #include +#include + // Developer switch to use emulation mode on all backends, even those that // report native support, this is useful for debugging. #define FORCE_EMULATION_MODE 0 @@ -71,6 +73,40 @@ bool checkForRequirement(sycl::detail::AccessorImplHost *Req, } return SuccessorAddedDep; } + +/// Visits a node on the graph and it's successors recursively in a depth-first +/// approach. +/// @param[in] Node The current node being visited. +/// @param[in,out] VisitedNodes A set of unique nodes which have already been +/// visited. +/// @param[in] NodeStack Stack of nodes which are currently being visited on the +/// current path through the graph. +/// @param[in] NodeFunc The function object to be run on each node. A return +/// value of true indicates the search should be ended immediately and the +/// function will return. +/// @return True if the search should end immediately, false if not. +bool visitNodeDepthFirst( + std::shared_ptr Node, + std::set> &VisitedNodes, + std::deque> &NodeStack, + std::function &, + std::deque> &)> + NodeFunc) { + auto EarlyReturn = NodeFunc(Node, NodeStack); + if (EarlyReturn) { + return true; + } + NodeStack.push_back(Node); + Node->MVisited = true; + VisitedNodes.emplace(Node); + for (auto &Successor : Node->MSuccessors) { + if (visitNodeDepthFirst(Successor, VisitedNodes, NodeStack, NodeFunc)) { + return true; + } + } + NodeStack.pop_back(); + return false; +} } // anonymous namespace void exec_graph_impl::schedule() { @@ -226,6 +262,105 @@ bool graph_impl::clearQueues() { return AnyQueuesCleared; } +void graph_impl::searchDepthFirst( + std::function &, + std::deque> &)> + NodeFunc) { + // Track nodes visited during the search which can be used by NodeFunc in + // depth first search queries. Currently unusued but is an + // integral part of depth first searches. + std::set> VisitedNodes; + + for (auto &Root : MRoots) { + std::deque> NodeStack; + if (visitNodeDepthFirst(Root, VisitedNodes, NodeStack, NodeFunc)) { + break; + } + } + + // Reset the visited status of all nodes encountered in the search. + for (auto &Node : VisitedNodes) { + Node->MVisited = false; + } +} + +bool graph_impl::checkForCycles() { + // Using a depth-first search and checking if we vist a node more than once in + // the current path to identify if there are cycles. + bool CycleFound = false; + auto CheckFunc = [&](std::shared_ptr &Node, + std::deque> &NodeStack) { + // If the current node has previously been found in the current path through + // the graph then we have a cycle and we end the search early. + if (std::find(NodeStack.begin(), NodeStack.end(), Node) != + NodeStack.end()) { + CycleFound = true; + return true; + } + return false; + }; + searchDepthFirst(CheckFunc); + return CycleFound; +} + +void graph_impl::makeEdge(std::shared_ptr Src, + std::shared_ptr Dest) { + if (MRecordingQueues.size()) { + throw sycl::exception(make_error_code(sycl::errc::invalid), + "make_edge() cannot be called when a queue is " + "currently recording commands to a graph."); + } + if (Src == Dest) { + throw sycl::exception( + make_error_code(sycl::errc::invalid), + "make_edge() cannot be called when Src and Dest are the same."); + } + + bool SrcFound = false; + bool DestFound = false; + auto CheckForNodes = [&](std::shared_ptr &Node, + std::deque> &) { + if (Node == Src) { + SrcFound = true; + } + if (Node == Dest) { + DestFound = true; + } + return SrcFound && DestFound; + }; + + searchDepthFirst(CheckForNodes); + + if (!SrcFound) { + throw sycl::exception(make_error_code(sycl::errc::invalid), + "Src must be a node inside the graph."); + } + if (!DestFound) { + throw sycl::exception(make_error_code(sycl::errc::invalid), + "Dest must be a node inside the graph."); + } + + // We need to add the edges first before checking for cycles + Src->registerSuccessor(Dest, Src); + + // We can skip cycle checks if either Dest has no successors (cycle not + // possible) or cycle checks have been disabled with the no_cycle_check + // property; + if (Dest->MSuccessors.empty() || !MSkipCycleChecks) { + bool CycleFound = checkForCycles(); + + if (CycleFound) { + // Remove the added successor and predecessor + Src->MSuccessors.pop_back(); + Dest->MPredecessors.pop_back(); + + throw sycl::exception(make_error_code(sycl::errc::invalid), + "Command graphs cannot contain cycles."); + } + } + removeRoot(Dest); // remove receiver from root node list +} + // Check if nodes are empty and if so loop back through predecessors until we // find the real dependency. void exec_graph_impl::findRealDeps( @@ -463,8 +598,9 @@ exec_graph_impl::enqueue(const std::shared_ptr &Queue, modifiable_command_graph::modifiable_command_graph( const sycl::context &SyclContext, const sycl::device &SyclDevice, - const sycl::property_list &) - : impl(std::make_shared(SyclContext, SyclDevice)) {} + const sycl::property_list &PropList) + : impl(std::make_shared(SyclContext, SyclDevice, + PropList)) {} node modifiable_command_graph::addImpl(const std::vector &Deps) { std::vector> DepImpls; @@ -494,9 +630,7 @@ void modifiable_command_graph::make_edge(node &Src, node &Dest) { std::shared_ptr ReceiverImpl = sycl::detail::getSyclObjImpl(Dest); - SenderImpl->registerSuccessor(ReceiverImpl, - SenderImpl); // register successor - impl->removeRoot(ReceiverImpl); // remove receiver from root node list + impl->makeEdge(SenderImpl, ReceiverImpl); } command_graph diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index 4635f4154a9c9..c5f7efc084ffa 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -43,6 +44,9 @@ class node_impl { /// Command group object which stores all args etc needed to enqueue the node std::unique_ptr MCommandGroup; + /// Used for tracking visited status during cycle checks. + bool MVisited = false; + /// Add successor to the node. /// @param Node Node to add as a successor. /// @param Prev Predecessor to \p node being added as successor. @@ -51,6 +55,10 @@ class node_impl { /// use a raw \p this pointer, so the extra \Prev parameter is passed. void registerSuccessor(const std::shared_ptr &Node, const std::shared_ptr &Prev) { + if (std::find(MSuccessors.begin(), MSuccessors.end(), Node) != + MSuccessors.end()) { + return; + } MSuccessors.push_back(Node); Node->registerPredecessor(Prev); } @@ -58,6 +66,12 @@ class node_impl { /// Add predecessor to the node. /// @param Node Node to add as a predecessor. void registerPredecessor(const std::shared_ptr &Node) { + if (std::find_if(MPredecessors.begin(), MPredecessors.end(), + [&Node](const std::weak_ptr &Ptr) { + return Ptr.lock() == Node; + }) != MPredecessors.end()) { + return; + } MPredecessors.push_back(Node); } @@ -206,9 +220,6 @@ class node_impl { case sycl::detail::CG::CGTYPE::AdviseUSM: Stream << "CGAdviseUSM \\n"; break; - case sycl::detail::CG::CGTYPE::CodeplayInteropTask: - Stream << "CGInteropTask \\n"; - break; case sycl::detail::CG::CGTYPE::CodeplayHostTask: Stream << "CGHostTask \\n"; break; @@ -331,9 +342,15 @@ class graph_impl { /// Constructor. /// @param SyclContext Context to use for graph. /// @param SyclDevice Device to create nodes with. - graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice) + /// @param PropList Optional list of properties. + graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice, + const sycl::property_list &PropList = {}) : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(), - MEventsMap(), MInorderQueueMap() {} + MEventsMap(), MInorderQueueMap() { + if (PropList.has_property()) { + MSkipCycleChecks = true; + } + } /// Insert node into list of root nodes. /// @param Root Node to add to list of root nodes. @@ -557,8 +574,32 @@ class graph_impl { return true; } + /// Make an edge between two nodes in the graph. Performs some mandatory + /// error checks as well as an optional check for cycles introduced by making + /// this edge. + /// @param Src The source of the new edge. + /// @param Dest The destination of the new edge. + void makeEdge(std::shared_ptr Src, + std::shared_ptr Dest); private: + /// Iterate over the graph depth-first and run \p NodeFunc on each node. + /// @param NodeFunc A function which receives as input a node in the graph to + /// perform operations on as well as the stack of nodes encountered in the + /// current path. The return value of this function determines whether an + /// early exit is triggered, if true the depth-first search will end + /// immediately and no further nodes will be visited. + void + searchDepthFirst(std::function &, + std::deque> &)> + NodeFunc); + + /// Check the graph for cycles by performing a depth-first search of the + /// graph. If a node is visited more than once in a given path through the + /// graph, a cycle is present and the search ends immediately. + /// @return True if a cycle is detected, false if not. + bool checkForCycles(); + /// Context associated with this graph. sycl::context MContext; /// Device associated with this graph. All graph nodes will execute on this @@ -576,6 +617,9 @@ class graph_impl { std::map, std::shared_ptr, std::owner_less>> MInorderQueueMap; + /// Controls whether we skip the cycle checks in makeEdge, set by the presence + /// of the no_cycle_check property on construction. + bool MSkipCycleChecks = false; }; /// Class representing the implementation of command_graph. diff --git a/sycl/test-e2e/Graph/Explicit/cycle_error.cpp b/sycl/test-e2e/Graph/Explicit/cycle_error.cpp new file mode 100644 index 0000000000000..2ca29aa67b9cf --- /dev/null +++ b/sycl/test-e2e/Graph/Explicit/cycle_error.cpp @@ -0,0 +1,86 @@ +// REQUIRES: level_zero, gpu +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// Tests that introducing a cycle to the graph will throw when +// property::graph::no_cycle_check is not passed to the graph constructor and +// will not throw when it is. + +#include "../graph_common.hpp" + +void CreateGraphWithCyclesTest(bool DisableCycleChecks) { + + // If we are testing without cycle checks we need to do multiple iterations so + // we can test multiple types of cycle, since introducing a cycle with no + // checks may put the graph into an undefined state. + const size_t Iterations = DisableCycleChecks ? 2 : 1; + + queue Queue; + + property_list Props; + + if (DisableCycleChecks) { + Props = {ext::oneapi::experimental::property::graph::no_cycle_check{}}; + } + + for (size_t i = 0; i < Iterations; i++) { + ext::oneapi::experimental::command_graph Graph{Queue.get_context(), + Queue.get_device(), Props}; + + auto NodeA = Graph.add([&](sycl::handler &CGH) { + CGH.single_task([=]() {}); + }); + auto NodeB = Graph.add([&](sycl::handler &CGH) { + CGH.single_task([=]() {}); + }); + auto NodeC = Graph.add([&](sycl::handler &CGH) { + CGH.single_task([=]() {}); + }); + + // Make normal edges + std::error_code ErrorCode = sycl::make_error_code(sycl::errc::success); + try { + Graph.make_edge(NodeA, NodeB); + Graph.make_edge(NodeB, NodeC); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + + assert(ErrorCode == sycl::errc::success); + + // Introduce cycles to the graph. If we are performing cycle checks we can + // test both cycles, if they are disabled we need to test one per iteration. + if (i == 0 || !DisableCycleChecks) { + ErrorCode = sycl::make_error_code(sycl::errc::success); + try { + Graph.make_edge(NodeC, NodeA); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + + assert(ErrorCode == + (DisableCycleChecks ? sycl::errc::success : sycl::errc::invalid)); + } + + if (i == 1 || !DisableCycleChecks) { + ErrorCode = sycl::make_error_code(sycl::errc::success); + try { + Graph.make_edge(NodeC, NodeB); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + + assert(ErrorCode == + (DisableCycleChecks ? sycl::errc::success : sycl::errc::invalid)); + } + } +} + +int main() { + // Test with cycle checks + CreateGraphWithCyclesTest(false); + // Test without cycle checks + CreateGraphWithCyclesTest(true); + + return 0; +} diff --git a/sycl/unittests/Extensions/CommandGraph.cpp b/sycl/unittests/Extensions/CommandGraph.cpp index 1193439f196cc..3de65e97514dd 100644 --- a/sycl/unittests/Extensions/CommandGraph.cpp +++ b/sycl/unittests/Extensions/CommandGraph.cpp @@ -592,3 +592,105 @@ TEST_F(CommandGraphTest, InOrderQueueWithEmptyLast) { ASSERT_EQ(*ScheduleIt, PtrNode2); ASSERT_EQ(InOrderQueue.get_context(), GraphExecImpl->getContext()); } + +TEST_F(CommandGraphTest, MakeEdgeErrors) { + // Set up some nodes in the graph + auto NodeA = Graph.add( + [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); + auto NodeB = Graph.add( + [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); + + // Test error on calling make_edge when a queue is recording to the graph + Graph.begin_recording(Queue); + ASSERT_THROW( + { + try { + Graph.make_edge(NodeA, NodeB); + } catch (const sycl::exception &e) { + ASSERT_EQ(e.code(), make_error_code(sycl::errc::invalid)); + throw; + } + }, + sycl::exception); + + Graph.end_recording(Queue); + + // Test error on Src and Dest being the same + ASSERT_THROW( + { + try { + Graph.make_edge(NodeA, NodeA); + } catch (const sycl::exception &e) { + ASSERT_EQ(e.code(), make_error_code(sycl::errc::invalid)); + throw; + } + }, + sycl::exception); + + // Test Src or Dest not being found in the graph + experimental::command_graph GraphOther{ + Queue.get_context(), Queue.get_device()}; + auto NodeOther = GraphOther.add( + [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); + + ASSERT_THROW( + { + try { + Graph.make_edge(NodeA, NodeOther); + } catch (const sycl::exception &e) { + ASSERT_EQ(e.code(), make_error_code(sycl::errc::invalid)); + throw; + } + }, + sycl::exception); + ASSERT_THROW( + { + try { + Graph.make_edge(NodeOther, NodeB); + } catch (const sycl::exception &e) { + ASSERT_EQ(e.code(), make_error_code(sycl::errc::invalid)); + throw; + } + }, + sycl::exception); + + // Test that adding a cycle with cycle checks leaves the graph in the correct + // state. + + auto CheckGraphStructure = [&]() { + auto GraphImpl = sycl::detail::getSyclObjImpl(Graph); + auto NodeAImpl = sycl::detail::getSyclObjImpl(NodeA); + auto NodeBImpl = sycl::detail::getSyclObjImpl(NodeB); + + ASSERT_EQ(GraphImpl->MRoots.size(), 1lu); + ASSERT_EQ(*(GraphImpl->MRoots.begin()), NodeAImpl); + + ASSERT_EQ(NodeAImpl->MSuccessors.size(), 1lu); + ASSERT_EQ(NodeAImpl->MPredecessors.size(), 0lu); + ASSERT_EQ(NodeAImpl->MSuccessors.front(), NodeBImpl); + + ASSERT_EQ(NodeBImpl->MSuccessors.size(), 0lu); + ASSERT_EQ(NodeBImpl->MPredecessors.size(), 1lu); + ASSERT_EQ(NodeBImpl->MPredecessors.front().lock(), NodeAImpl); + }; + // Make a normal edge + ASSERT_NO_THROW(Graph.make_edge(NodeA, NodeB)); + + // Check the expected structure of the graph + CheckGraphStructure(); + + // Introduce a cycle, make sure it throws + ASSERT_THROW( + { + try { + Graph.make_edge(NodeB, NodeA); + } catch (const sycl::exception &e) { + ASSERT_EQ(e.code(), make_error_code(sycl::errc::invalid)); + throw; + } + }, + sycl::exception); + + // Re-check graph structure to make sure the graph state has not been modified + CheckGraphStructure(); +}