Skip to content

Commit

Permalink
Merge pull request #241 from reble/julianmi/expand-unittests
Browse files Browse the repository at this point in the history
[SYCL][Graph] Expand unittests
  • Loading branch information
julianmi authored Jun 28, 2023
2 parents 66c696c + 05bead8 commit 9368f50
Showing 1 changed file with 229 additions and 25 deletions.
254 changes: 229 additions & 25 deletions sycl/unittests/Extensions/CommandGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,50 +40,117 @@ class CommandGraphTest : public ::testing::Test {
TEST_F(CommandGraphTest, AddNode) {
auto GraphImpl = sycl::detail::getSyclObjImpl(Graph);

ASSERT_TRUE(GraphImpl->MRoots.size() == 0);
ASSERT_TRUE(GraphImpl->MRoots.empty());

auto Node1 = Graph.add([&](sycl::handler &cgh) {});

ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1) != nullptr);
ASSERT_TRUE(GraphImpl->MRoots.size() == 1);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size() == 0);
auto Node1 = Graph.add(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); });
ASSERT_NE(sycl::detail::getSyclObjImpl(Node1), nullptr);
ASSERT_FALSE(sycl::detail::getSyclObjImpl(Node1)->isEmpty());
ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);
ASSERT_EQ(*GraphImpl->MRoots.begin(), sycl::detail::getSyclObjImpl(Node1));
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.empty());
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty());

// Add a node which depends on the first
auto Node2 = Graph.add([&](sycl::handler &cgh) {},
{experimental::property::node::depends_on(Node1)});
ASSERT_TRUE(GraphImpl->MRoots.size() == 1);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size() == 1);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.front() ==
sycl::detail::getSyclObjImpl(Node2));
auto Node2Deps = experimental::property::node::depends_on(Node1);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2Deps.get_dependencies().front()),
sycl::detail::getSyclObjImpl(Node1));
auto Node2 = Graph.add([&](sycl::handler &cgh) {}, {Node2Deps});
ASSERT_NE(sycl::detail::getSyclObjImpl(Node2), nullptr);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2)->isEmpty());
ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size(), 1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.front(),
sycl::detail::getSyclObjImpl(Node2));
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty());
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2)->MPredecessors.size(), 1lu);

// Add a third node which depends on both
auto Node3 =
Graph.add([&](sycl::handler &cgh) {},
{experimental::property::node::depends_on(Node1, Node2)});
ASSERT_TRUE(GraphImpl->MRoots.size() == 1);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size() == 2);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2)->MSuccessors.size() == 1);
ASSERT_NE(sycl::detail::getSyclObjImpl(Node3), nullptr);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node3)->isEmpty());
ASSERT_EQ(GraphImpl->MRoots.size(), 1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size(), 2lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2)->MSuccessors.size(), 1lu);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty());
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2)->MPredecessors.size(), 1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node3)->MPredecessors.size(), 2lu);

// Add a fourth node without any dependencies on the others
auto Node4 = Graph.add([&](sycl::handler &cgh) {});
ASSERT_TRUE(GraphImpl->MRoots.size() == 2);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size() == 2);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2)->MSuccessors.size() == 1);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node3)->MSuccessors.size() == 0);
ASSERT_NE(sycl::detail::getSyclObjImpl(Node4), nullptr);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node4)->isEmpty());
ASSERT_EQ(GraphImpl->MRoots.size(), 2lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size(), 2lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2)->MSuccessors.size(), 1lu);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node3)->MSuccessors.empty());
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty());
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2)->MPredecessors.size(), 1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node3)->MPredecessors.size(), 2lu);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node4)->MPredecessors.empty());
}

TEST_F(CommandGraphTest, Finalize) {
auto GraphImpl = sycl::detail::getSyclObjImpl(Graph);

sycl::buffer<int> Buf(1);
auto Node1 = Graph.add([&](sycl::handler &cgh) {
sycl::accessor A(Buf, cgh, sycl::write_only, sycl::no_init);
cgh.single_task<class TestKernel1>([=]() { A[0] = 1; });
});

// Add independent node
auto Node2 = Graph.add(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); });

// Add a node that depends on Node1 due to the accessor
auto Node3 = Graph.add([&](sycl::handler &cgh) {
sycl::accessor A(Buf, cgh, sycl::write_only, sycl::no_init);
cgh.single_task<class TestKernel2>([=]() { A[0] = 3; });
});

// Guarantee order of independent nodes 1 and 2
Graph.make_edge(Node2, Node1);

auto GraphExec = Graph.finalize();
auto GraphExecImpl = sycl::detail::getSyclObjImpl(GraphExec);

// The final schedule should contain three nodes in order: 2->1->3
auto Schedule = GraphExecImpl->getSchedule();
ASSERT_EQ(Schedule.size(), 3ul);
auto ScheduleIt = Schedule.begin();
ASSERT_EQ(*ScheduleIt, sycl::detail::getSyclObjImpl(Node2));
ScheduleIt++;
ASSERT_EQ(*ScheduleIt, sycl::detail::getSyclObjImpl(Node1));
ScheduleIt++;
ASSERT_EQ(*ScheduleIt, sycl::detail::getSyclObjImpl(Node3));
ASSERT_EQ(Queue.get_context(), GraphExecImpl->getContext());
}

TEST_F(CommandGraphTest, MakeEdge) {
auto GraphImpl = sycl::detail::getSyclObjImpl(Graph);

auto Node1 = Graph.add([&](sycl::handler &cgh) {});
// Add two independent nodes
auto Node1 = Graph.add(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); });
auto Node2 = Graph.add([&](sycl::handler &cgh) {});
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size() == 0);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2)->MPredecessors.size() == 0);
ASSERT_EQ(GraphImpl->MRoots.size(), 2ul);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.empty());
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty());
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2)->MSuccessors.empty());
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2)->MPredecessors.empty());

// Connect nodes and verify order
Graph.make_edge(Node1, Node2);

ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size() == 1);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2)->MPredecessors.size() == 1);
ASSERT_EQ(GraphImpl->MRoots.size(), 1ul);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.size(), 1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1)->MSuccessors.front(),
sycl::detail::getSyclObjImpl(Node2));
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node1)->MPredecessors.empty());
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2)->MSuccessors.empty());
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2)->MPredecessors.size(), 1lu);
}

TEST_F(CommandGraphTest, BeginEndRecording) {
Expand Down Expand Up @@ -157,3 +224,140 @@ TEST_F(CommandGraphTest, BeginEndRecording) {
// Vector end should still return true as Queue will have state changed
ASSERT_TRUE(Graph.end_recording({Queue, Queue2}));
}

TEST_F(CommandGraphTest, GetCGCopy) {
auto Node1 = Graph.add([&](sycl::handler &cgh) {});
auto Node2 = Graph.add(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); },
{experimental::property::node::depends_on(Node1)});

// Get copy of CG of Node2 and check equality
auto Node2Imp = sycl::detail::getSyclObjImpl(Node2);
auto Node2CGCopy = Node2Imp->getCGCopy();
ASSERT_EQ(Node2CGCopy->getType(), Node2Imp->MCGType);
ASSERT_EQ(Node2CGCopy->getType(), sycl::detail::CG::Kernel);
ASSERT_EQ(Node2CGCopy->getType(), Node2Imp->MCommandGroup->getType());
ASSERT_EQ(Node2CGCopy->getAccStorage(),
Node2Imp->MCommandGroup->getAccStorage());
ASSERT_EQ(Node2CGCopy->getArgsStorage(),
Node2Imp->MCommandGroup->getArgsStorage());
ASSERT_EQ(Node2CGCopy->getEvents(), Node2Imp->MCommandGroup->getEvents());
ASSERT_EQ(Node2CGCopy->getRequirements(),
Node2Imp->MCommandGroup->getRequirements());
ASSERT_EQ(Node2CGCopy->getSharedPtrStorage(),
Node2Imp->MCommandGroup->getSharedPtrStorage());
}
TEST_F(CommandGraphTest, SubGraph) {
// Add sub-graph with two nodes
auto Node1Graph = Graph.add(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); });
auto Node2Graph = Graph.add(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); },
{experimental::property::node::depends_on(Node1Graph)});
auto GraphExec = Graph.finalize();

// Add node to main graph followed by sub-graph and another node
experimental::command_graph MainGraph(Queue.get_context(), Dev);
auto Node1MainGraph = MainGraph.add(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); });
auto Node2MainGraph =
MainGraph.add([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); },
{experimental::property::node::depends_on(Node1MainGraph)});
auto Node3MainGraph = MainGraph.add(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); },
{experimental::property::node::depends_on(Node2MainGraph)});

// Assert order of the added sub-graph
ASSERT_NE(sycl::detail::getSyclObjImpl(Node2MainGraph), nullptr);
ASSERT_TRUE(sycl::detail::getSyclObjImpl(Node2MainGraph)->isEmpty());
ASSERT_EQ(sycl::detail::getSyclObjImpl(MainGraph)->MRoots.size(), 1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1MainGraph)->MSuccessors.size(),
1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1MainGraph)->MSuccessors.front(),
sycl::detail::getSyclObjImpl(Node1Graph));
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2MainGraph)->MSuccessors.size(),
1lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node1MainGraph)->MPredecessors.size(),
0lu);
ASSERT_EQ(sycl::detail::getSyclObjImpl(Node2MainGraph)->MPredecessors.size(),
1lu);

// Finalize main graph and check schedule
auto MainGraphExec = MainGraph.finalize();
auto MainGraphExecImpl = sycl::detail::getSyclObjImpl(MainGraphExec);
auto Schedule = MainGraphExecImpl->getSchedule();
auto ScheduleIt = Schedule.begin();
ASSERT_EQ(Schedule.size(), 4ul);
ASSERT_EQ(*ScheduleIt, sycl::detail::getSyclObjImpl(Node1MainGraph));
ScheduleIt++;
ASSERT_EQ(*ScheduleIt, sycl::detail::getSyclObjImpl(Node1Graph));
ScheduleIt++;
ASSERT_EQ(*ScheduleIt, sycl::detail::getSyclObjImpl(Node2Graph));
ScheduleIt++;
ASSERT_EQ(*ScheduleIt, sycl::detail::getSyclObjImpl(Node3MainGraph));
ASSERT_EQ(Queue.get_context(), MainGraphExecImpl->getContext());
}

TEST_F(CommandGraphTest, RecordSubGraph) {
// Record sub-graph with two nodes
Graph.begin_recording(Queue);
auto Node1Graph = Queue.submit(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); });
auto Node2Graph = Queue.submit([&](sycl::handler &cgh) {
cgh.depends_on(Node1Graph);
cgh.single_task<class TestKernel>([]() {});
});
Graph.end_recording(Queue);
auto GraphExec = Graph.finalize();

// Add node to main graph followed by sub-graph and another node
experimental::command_graph MainGraph(Queue.get_context(), Dev);
MainGraph.begin_recording(Queue);
auto Node1MainGraph = Queue.submit(
[&](sycl::handler &cgh) { cgh.single_task<class TestKernel>([]() {}); });
auto Node2MainGraph = Queue.submit([&](handler &cgh) {
cgh.depends_on(Node1MainGraph);
cgh.ext_oneapi_graph(GraphExec);
});
auto Node3MainGraph = Queue.submit([&](sycl::handler &cgh) {
cgh.depends_on(Node2MainGraph);
cgh.single_task<class TestKernel>([]() {});
});
MainGraph.end_recording(Queue);

// Finalize main graph and check schedule
auto MainGraphExec = MainGraph.finalize();
auto MainGraphExecImpl = sycl::detail::getSyclObjImpl(MainGraphExec);
auto Schedule = MainGraphExecImpl->getSchedule();
auto ScheduleIt = Schedule.begin();
ASSERT_EQ(Schedule.size(), 4ul);

// The first and fourth nodes should have events associated with MainGraph but
// not graph. The second and third nodes were added as a sub-graph and should
// have events associated with Graph but not MainGraph.
ASSERT_ANY_THROW(
sycl::detail::getSyclObjImpl(Graph)->getEventForNode(*ScheduleIt));
ASSERT_EQ(
sycl::detail::getSyclObjImpl(MainGraph)->getEventForNode(*ScheduleIt),
sycl::detail::getSyclObjImpl(Node1MainGraph));

ScheduleIt++;
ASSERT_ANY_THROW(
sycl::detail::getSyclObjImpl(MainGraph)->getEventForNode(*ScheduleIt));
ASSERT_EQ(sycl::detail::getSyclObjImpl(Graph)->getEventForNode(*ScheduleIt),
sycl::detail::getSyclObjImpl(Node1Graph));

ScheduleIt++;
ASSERT_ANY_THROW(
sycl::detail::getSyclObjImpl(MainGraph)->getEventForNode(*ScheduleIt));
ASSERT_EQ(sycl::detail::getSyclObjImpl(Graph)->getEventForNode(*ScheduleIt),
sycl::detail::getSyclObjImpl(Node2Graph));

ScheduleIt++;
ASSERT_ANY_THROW(
sycl::detail::getSyclObjImpl(Graph)->getEventForNode(*ScheduleIt));
ASSERT_EQ(
sycl::detail::getSyclObjImpl(MainGraph)->getEventForNode(*ScheduleIt),
sycl::detail::getSyclObjImpl(Node3MainGraph));
ASSERT_EQ(Queue.get_context(), MainGraphExecImpl->getContext());
}

0 comments on commit 9368f50

Please sign in to comment.