Skip to content

Commit

Permalink
[SYCL][Graph] Fixes a bug in getNumberNodes() (#316)
Browse files Browse the repository at this point in the history
* [SYCL][Graph] Fixes a bug in getNumberNodes()

GetNumbersNodes() didn't manage correctly multiple paths graphs.
The new implementation relies on searchDepthFirst to solve this issue.

* [SYCL][Graph] Adds test to verify getNumberofNodes output

* [SYCL][Graph] Improve comments + typos
  • Loading branch information
mfrancepillois authored Sep 6, 2023
1 parent adfee38 commit 3703b98
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
14 changes: 10 additions & 4 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,17 @@ class graph_impl {

// Returns the number of nodes in the Graph
// @return Number of nodes in the Graph
size_t getNumberOfNodes() const {
size_t getNumberOfNodes() {
size_t NumberOfNodes = 0;
for (const auto &Node : MRoots) {
NumberOfNodes += Node->depthSearchCount();
}
auto CountFunc = [&](std::shared_ptr<node_impl> &Node,
std::deque<std::shared_ptr<node_impl>> &) {
if (!Node->MVisited) {
NumberOfNodes++;
}
return false;
};
searchDepthFirst(CountFunc);

return NumberOfNodes;
}

Expand Down
29 changes: 26 additions & 3 deletions sycl/unittests/Extensions/CommandGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,19 @@ bool depthSearchSuccessorCheck(

/// Submits four kernels with diamond dependency to the queue Q
/// @param Q Queue to submit nodes to.
void runKernels(queue Q) {
auto NodeA = Q.submit(
[&](sycl::handler &cgh) { cgh.single_task<TestKernel<>>([]() {}); });
/// @param Dep Events to add as previous dependencies to the node group
/// @return The event associated with the last kernel submitted
sycl::event runKernels(queue Q, std::vector<sycl::event> Dep = {}) {
sycl::event NodeA;
if (Dep.size() > 0) {
NodeA = Q.submit([&](sycl::handler &cgh) {
cgh.depends_on(Dep);
cgh.single_task<TestKernel<>>([]() {});
});
} else {
NodeA = Q.submit(
[&](sycl::handler &cgh) { cgh.single_task<TestKernel<>>([]() {}); });
}
auto NodeB = Q.submit([&](sycl::handler &cgh) {
cgh.depends_on(NodeA);
cgh.single_task<TestKernel<>>([]() {});
Expand All @@ -406,6 +416,7 @@ void runKernels(queue Q) {
cgh.depends_on({NodeB, NodeC});
cgh.single_task<TestKernel<>>([]() {});
});
return NodeD;
}

/// Submits four kernels without any additional dependencies the queue Q
Expand Down Expand Up @@ -2091,3 +2102,15 @@ TEST_F(CommandGraphTest, FillMemsetNodes) {
sycl::free(USMPtr, Queue);
}
}

TEST_F(CommandGraphTest, GetNumberOfNodes) {
// Create graph made of nodes linked as a double diamond
Graph.begin_recording(Queue);
auto Event = runKernels(Queue);
runKernels(Queue, {Event});
Graph.end_recording(Queue);

// Check the number of nodes returned by getNumberOfNodes
auto GraphImpl = sycl::detail::getSyclObjImpl(Graph);
EXPECT_EQ(GraphImpl->getNumberOfNodes(), 8lu);
}

0 comments on commit 3703b98

Please sign in to comment.