Skip to content

Commit

Permalink
[SYCL][Graph] Addressing PR feedback
Browse files Browse the repository at this point in the history
- Style fixes
- Make MRoots weak_ptrs instead of shared
  • Loading branch information
Bensuo committed Oct 23, 2023
1 parent 508ee90 commit c9cfc73
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
18 changes: 12 additions & 6 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,10 @@ void sortTopological(std::shared_ptr<node_impl> NodeImpl,
for (auto &Succ : NodeImpl->MSuccessors) {
// Check if we've already scheduled this node
auto NextNode = Succ.lock();
if (std::find(Schedule.begin(), Schedule.end(), NextNode) == Schedule.end())
if (std::find(Schedule.begin(), Schedule.end(), NextNode) ==
Schedule.end()) {
sortTopological(NextNode, Schedule);
}
}

Schedule.push_front(NodeImpl);
Expand All @@ -93,7 +95,7 @@ void sortTopological(std::shared_ptr<node_impl> NodeImpl,
void exec_graph_impl::schedule() {
if (MSchedule.empty()) {
for (auto &Node : MGraphImpl->MRoots) {
sortTopological(Node, MSchedule);
sortTopological(Node.lock(), MSchedule);
}
}
}
Expand Down Expand Up @@ -264,11 +266,14 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
// If any of this node's successors have this requirement then we skip
// adding the current node as a dependency.
for (auto &Succ : Node->MSuccessors) {
if (Succ.lock()->hasRequirement(Req))
if (Succ.lock()->hasRequirement(Req)) {
ShouldAddDep = false;
break;
}
}
if (ShouldAddDep)
if (ShouldAddDep) {
UniqueDeps.insert(Node);
}
}
}
}
Expand Down Expand Up @@ -328,7 +333,7 @@ void graph_impl::searchDepthFirst(

for (auto &Root : MRoots) {
std::deque<std::shared_ptr<node_impl>> NodeStack;
if (visitNodeDepthFirst(Root, VisitedNodes, NodeStack, NodeFunc)) {
if (visitNodeDepthFirst(Root.lock(), VisitedNodes, NodeStack, NodeFunc)) {
break;
}
}
Expand Down Expand Up @@ -374,8 +379,9 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
SrcFound |= Node == Src;
DestFound |= Node == Dest;

if (SrcFound && DestFound)
if (SrcFound && DestFound) {
break;
}
}

if (!SrcFound) {
Expand Down
22 changes: 13 additions & 9 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ class node_impl {
/// @param CompareContentOnly Skip comparisons related to graph structure,
/// compare only the type and command groups of the nodes
/// @return True if the two nodes are similar
bool isSimilar(std::shared_ptr<node_impl> Node,
bool CompareContentOnly = false) {
bool isSimilar(const std::shared_ptr<node_impl> &Node,
bool CompareContentOnly = false) const {
if (!CompareContentOnly) {
if (MSuccessors.size() != Node->MSuccessors.size())
return false;
Expand Down Expand Up @@ -379,7 +379,8 @@ class graph_impl {
sycl::device getDevice() const { return MDevice; }

/// List of root nodes.
std::set<std::shared_ptr<node_impl>> MRoots;
std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
MRoots;

/// Storage for all nodes contained within a graph. Nodes are connected to
/// each other via weak_ptrs and so do not extend each other's lifetimes.
Expand Down Expand Up @@ -433,8 +434,8 @@ class graph_impl {
/// @param NodeA pointer to the first node for comparison
/// @param NodeB pointer to the second node for comparison
/// @return true is same structure found, false otherwise
static bool checkNodeRecursive(std::shared_ptr<node_impl> NodeA,
std::shared_ptr<node_impl> NodeB) {
static bool checkNodeRecursive(const std::shared_ptr<node_impl> &NodeA,
const std::shared_ptr<node_impl> &NodeB) {
size_t FoundCnt = 0;
for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
Expand Down Expand Up @@ -509,10 +510,13 @@ class graph_impl {
}

size_t RootsFound = 0;
for (std::shared_ptr<node_impl> NodeA : MRoots) {
for (std::shared_ptr<node_impl> NodeB : Graph->MRoots) {
if (NodeA->isSimilar(NodeB)) {
if (checkNodeRecursive(NodeA, NodeB)) {
for (std::weak_ptr<node_impl> NodeA : MRoots) {
for (std::weak_ptr<node_impl> NodeB : Graph->MRoots) {
auto NodeALocked = NodeA.lock();
auto NodeBLocked = NodeB.lock();

if (NodeALocked->isSimilar(NodeBLocked)) {
if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
RootsFound++;
break;
}
Expand Down
22 changes: 14 additions & 8 deletions sycl/unittests/Extensions/CommandGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ TEST_F(CommandGraphTest, AddNode) {
ASSERT_NE(sycl::detail::getSyclObjImpl(Node1), nullptr);
ASSERT_FALSE(sycl::detail::getSyclObjImpl(Node1)->isEmpty());
ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);
ASSERT_EQ(*GraphImpl->MRoots.begin(), sycl::detail::getSyclObjImpl(Node1));
ASSERT_EQ((*GraphImpl->MRoots.begin()).lock(),
sycl::detail::getSyclObjImpl(Node1));
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.empty());
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty());

Expand Down Expand Up @@ -1269,7 +1270,8 @@ TEST_F(CommandGraphTest, EnqueueBarrier) {
// / \
// (4) (5)
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
for (auto Node : GraphImpl->MRoots) {
for (auto Root : GraphImpl->MRoots) {
auto Node = Root.lock();
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
auto BarrierNode = Node->MSuccessors.front().lock();
ASSERT_EQ(BarrierNode->MCGType, sycl::detail::CG::Barrier);
Expand Down Expand Up @@ -1309,7 +1311,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierMultipleQueues) {
// / \
// (4) (5)
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
for (auto Node : GraphImpl->MRoots) {
for (auto Root : GraphImpl->MRoots) {
auto Node = Root.lock();
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
auto BarrierNode = Node->MSuccessors.front().lock();
ASSERT_EQ(BarrierNode->MCGType, sycl::detail::CG::Barrier);
Expand Down Expand Up @@ -1352,7 +1355,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitList) {
// / \ /
// (4) (5)
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
for (auto Node : GraphImpl->MRoots) {
for (auto Root : GraphImpl->MRoots) {
auto Node = Root.lock();
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
auto SuccNode = Node->MSuccessors.front().lock();
if (SuccNode->MCGType == sycl::detail::CG::Barrier) {
Expand Down Expand Up @@ -1408,7 +1412,8 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitListMultipleQueues) {
// \|/
// (B2)
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
for (auto Node : GraphImpl->MRoots) {
for (auto Root : GraphImpl->MRoots) {
auto Node = Root.lock();
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
auto SuccNode = Node->MSuccessors.front().lock();
if (SuccNode->MCGType == sycl::detail::CG::Barrier) {
Expand Down Expand Up @@ -1470,7 +1475,8 @@ TEST_F(CommandGraphTest, EnqueueMultipleBarrier) {
// / | \
// (6) (7) (8) (those nodes also have B1 as a predecessor)
ASSERT_EQ(GraphImpl->MRoots.size(), 3lu);
for (auto Node : GraphImpl->MRoots) {
for (auto Root : GraphImpl->MRoots) {
auto Node = Root.lock();
ASSERT_EQ(Node->MSuccessors.size(), 1lu);
auto SuccNode = Node->MSuccessors.front().lock();
if (SuccNode->MCGType == sycl::detail::CG::Barrier) {
Expand Down Expand Up @@ -1824,7 +1830,7 @@ TEST_F(CommandGraphTest, MakeEdgeErrors) {
auto NodeBImpl = sycl::detail::getSyclObjImpl(NodeB);

ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);
ASSERT_EQ(*(GraphImpl->MRoots.begin()), NodeAImpl);
ASSERT_EQ((*GraphImpl->MRoots.begin()).lock(), NodeAImpl);

ASSERT_EQ(NodeAImpl->MSuccessors.size(), 1lu);
ASSERT_EQ(NodeAImpl->MPredecessors.size(), 0lu);
Expand Down Expand Up @@ -2070,7 +2076,7 @@ TEST_F(MultiThreadGraphTest, RecordAddNodesInOrderQueue) {
ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);

// Check structure graph
auto CurrentNode = *GraphImpl->MRoots.begin();
auto CurrentNode = (*GraphImpl->MRoots.begin()).lock();
for (size_t i = 1; i <= GraphImpl->getNumberOfNodes(); i++) {
EXPECT_LE(CurrentNode->MSuccessors.size(), 1lu);

Expand Down

0 comments on commit c9cfc73

Please sign in to comment.