diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index e8543efa6d666..8e1667e189f06 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -82,8 +82,10 @@ void sortTopological(std::shared_ptr NodeImpl, for (auto &Succ : NodeImpl->MSuccessors) { // Check if we've already scheduled this node auto NextNode = Succ.lock(); - if (std::find(Schedule.begin(), Schedule.end(), NextNode) == Schedule.end()) + if (std::find(Schedule.begin(), Schedule.end(), NextNode) == + Schedule.end()) { sortTopological(NextNode, Schedule); + } } Schedule.push_front(NodeImpl); @@ -93,7 +95,7 @@ void sortTopological(std::shared_ptr NodeImpl, void exec_graph_impl::schedule() { if (MSchedule.empty()) { for (auto &Node : MGraphImpl->MRoots) { - sortTopological(Node, MSchedule); + sortTopological(Node.lock(), MSchedule); } } } @@ -264,11 +266,14 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType, // If any of this node's successors have this requirement then we skip // adding the current node as a dependency. for (auto &Succ : Node->MSuccessors) { - if (Succ.lock()->hasRequirement(Req)) + if (Succ.lock()->hasRequirement(Req)) { ShouldAddDep = false; + break; + } } - if (ShouldAddDep) + if (ShouldAddDep) { UniqueDeps.insert(Node); + } } } } @@ -328,7 +333,7 @@ void graph_impl::searchDepthFirst( for (auto &Root : MRoots) { std::deque> NodeStack; - if (visitNodeDepthFirst(Root, VisitedNodes, NodeStack, NodeFunc)) { + if (visitNodeDepthFirst(Root.lock(), VisitedNodes, NodeStack, NodeFunc)) { break; } } @@ -374,8 +379,9 @@ void graph_impl::makeEdge(std::shared_ptr Src, SrcFound |= Node == Src; DestFound |= Node == Dest; - if (SrcFound && DestFound) + if (SrcFound && DestFound) { break; + } } if (!SrcFound) { diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index 8bb22e82b931a..f1fcb6b09751f 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -183,8 +183,8 @@ class node_impl { /// @param CompareContentOnly Skip comparisons related to graph structure, /// compare only the type and command groups of the nodes /// @return True if the two nodes are similar - bool isSimilar(std::shared_ptr Node, - bool CompareContentOnly = false) { + bool isSimilar(const std::shared_ptr &Node, + bool CompareContentOnly = false) const { if (!CompareContentOnly) { if (MSuccessors.size() != Node->MSuccessors.size()) return false; @@ -379,7 +379,8 @@ class graph_impl { sycl::device getDevice() const { return MDevice; } /// List of root nodes. - std::set> MRoots; + std::set, std::owner_less>> + MRoots; /// Storage for all nodes contained within a graph. Nodes are connected to /// each other via weak_ptrs and so do not extend each other's lifetimes. @@ -433,8 +434,8 @@ class graph_impl { /// @param NodeA pointer to the first node for comparison /// @param NodeB pointer to the second node for comparison /// @return true is same structure found, false otherwise - static bool checkNodeRecursive(std::shared_ptr NodeA, - std::shared_ptr NodeB) { + static bool checkNodeRecursive(const std::shared_ptr &NodeA, + const std::shared_ptr &NodeB) { size_t FoundCnt = 0; for (std::weak_ptr &SuccA : NodeA->MSuccessors) { for (std::weak_ptr &SuccB : NodeB->MSuccessors) { @@ -509,10 +510,13 @@ class graph_impl { } size_t RootsFound = 0; - for (std::shared_ptr NodeA : MRoots) { - for (std::shared_ptr NodeB : Graph->MRoots) { - if (NodeA->isSimilar(NodeB)) { - if (checkNodeRecursive(NodeA, NodeB)) { + for (std::weak_ptr NodeA : MRoots) { + for (std::weak_ptr NodeB : Graph->MRoots) { + auto NodeALocked = NodeA.lock(); + auto NodeBLocked = NodeB.lock(); + + if (NodeALocked->isSimilar(NodeBLocked)) { + if (checkNodeRecursive(NodeALocked, NodeBLocked)) { RootsFound++; break; } diff --git a/sycl/unittests/Extensions/CommandGraph.cpp b/sycl/unittests/Extensions/CommandGraph.cpp index 3dfe574c8e4a6..b37f9d7221ca1 100644 --- a/sycl/unittests/Extensions/CommandGraph.cpp +++ b/sycl/unittests/Extensions/CommandGraph.cpp @@ -497,7 +497,8 @@ TEST_F(CommandGraphTest, AddNode) { ASSERT_NE(sycl::detail::getSyclObjImpl(Node1), nullptr); ASSERT_FALSE(sycl::detail::getSyclObjImpl(Node1)->isEmpty()); ASSERT_EQ(GraphImpl->MRoots.size(), 1lu); - ASSERT_EQ(*GraphImpl->MRoots.begin(), sycl::detail::getSyclObjImpl(Node1)); + ASSERT_EQ((*GraphImpl->MRoots.begin()).lock(), + sycl::detail::getSyclObjImpl(Node1)); ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.empty()); ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty()); @@ -1269,7 +1270,8 @@ TEST_F(CommandGraphTest, EnqueueBarrier) { // / \ // (4) (5) ASSERT_EQ(GraphImpl->MRoots.size(), 3lu); - for (auto Node : GraphImpl->MRoots) { + for (auto Root : GraphImpl->MRoots) { + auto Node = Root.lock(); ASSERT_EQ(Node->MSuccessors.size(), 1lu); auto BarrierNode = Node->MSuccessors.front().lock(); ASSERT_EQ(BarrierNode->MCGType, sycl::detail::CG::Barrier); @@ -1309,7 +1311,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierMultipleQueues) { // / \ // (4) (5) ASSERT_EQ(GraphImpl->MRoots.size(), 3lu); - for (auto Node : GraphImpl->MRoots) { + for (auto Root : GraphImpl->MRoots) { + auto Node = Root.lock(); ASSERT_EQ(Node->MSuccessors.size(), 1lu); auto BarrierNode = Node->MSuccessors.front().lock(); ASSERT_EQ(BarrierNode->MCGType, sycl::detail::CG::Barrier); @@ -1352,7 +1355,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitList) { // / \ / // (4) (5) ASSERT_EQ(GraphImpl->MRoots.size(), 3lu); - for (auto Node : GraphImpl->MRoots) { + for (auto Root : GraphImpl->MRoots) { + auto Node = Root.lock(); ASSERT_EQ(Node->MSuccessors.size(), 1lu); auto SuccNode = Node->MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CG::Barrier) { @@ -1408,7 +1412,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitListMultipleQueues) { // \|/ // (B2) ASSERT_EQ(GraphImpl->MRoots.size(), 3lu); - for (auto Node : GraphImpl->MRoots) { + for (auto Root : GraphImpl->MRoots) { + auto Node = Root.lock(); ASSERT_EQ(Node->MSuccessors.size(), 1lu); auto SuccNode = Node->MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CG::Barrier) { @@ -1470,7 +1475,8 @@ TEST_F(CommandGraphTest, EnqueueMultipleBarrier) { // / | \ // (6) (7) (8) (those nodes also have B1 as a predecessor) ASSERT_EQ(GraphImpl->MRoots.size(), 3lu); - for (auto Node : GraphImpl->MRoots) { + for (auto Root : GraphImpl->MRoots) { + auto Node = Root.lock(); ASSERT_EQ(Node->MSuccessors.size(), 1lu); auto SuccNode = Node->MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CG::Barrier) { @@ -1824,7 +1830,7 @@ TEST_F(CommandGraphTest, MakeEdgeErrors) { auto NodeBImpl = sycl::detail::getSyclObjImpl(NodeB); ASSERT_EQ(GraphImpl->MRoots.size(), 1lu); - ASSERT_EQ(*(GraphImpl->MRoots.begin()), NodeAImpl); + ASSERT_EQ((*GraphImpl->MRoots.begin()).lock(), NodeAImpl); ASSERT_EQ(NodeAImpl->MSuccessors.size(), 1lu); ASSERT_EQ(NodeAImpl->MPredecessors.size(), 0lu); @@ -2070,7 +2076,7 @@ TEST_F(MultiThreadGraphTest, RecordAddNodesInOrderQueue) { ASSERT_EQ(GraphImpl->MRoots.size(), 1lu); // Check structure graph - auto CurrentNode = *GraphImpl->MRoots.begin(); + auto CurrentNode = (*GraphImpl->MRoots.begin()).lock(); for (size_t i = 1; i <= GraphImpl->getNumberOfNodes(); i++) { EXPECT_LE(CurrentNode->MSuccessors.size(), 1lu);