Skip to content

Commit

Permalink
[SYCL][Graph] thread-safe: bug fix after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
mfrancepillois committed Aug 4, 2023
1 parent 079a042 commit 1597546
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 5 deletions.
4 changes: 2 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,6 @@ class graph_impl {
void makeEdge(std::shared_ptr<node_impl> Src,
std::shared_ptr<node_impl> Dest);


/// Throws an invalid exception if this function is called
/// while a queue is recording commands to the graph.
/// @param ExceptionMsg Message to append to the exception message
Expand All @@ -607,14 +606,15 @@ class graph_impl {
"is currently recording commands to a graph.");
}
}

/// Counts the nodes of the graph by visiting each root in MRoots and
/// accumulating the value reported by that root's depthSearchCount().
/// @return Sum of depthSearchCount() over all root nodes.
size_t getNumberOfNodes() const {
  size_t Total = 0;
  for (auto RootIt = MRoots.begin(); RootIt != MRoots.end(); ++RootIt) {
    Total += (*RootIt)->depthSearchCount();
  }
  return Total;
}

private:
Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ class queue_impl {

/// Stores the graph associated with this queue in MGraph.
/// MMutex is held while the member is assigned.
/// @param Graph Graph implementation to store in MGraph.
void setCommandGraph(
    std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
  std::lock_guard<std::mutex> Guard{MMutex};
  MGraph = Graph;
}

Expand Down
61 changes: 60 additions & 1 deletion sycl/test-e2e/Graph/Threading/finalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,30 @@

#include <mutex>
#include <thread>

/// Compares the schedules of two executable graphs element by element.
/// @param GraphA First executable graph implementation.
/// @param GraphB Second executable graph implementation.
/// @return true if both schedules have the same length and every node of
/// GraphA's schedule reports isSimilar() for the node at the same position in
/// GraphB's schedule; false otherwise.
bool checkExecGraphSchedule(
    std::shared_ptr<sycl::ext::oneapi::experimental::detail::exec_graph_impl>
        GraphA,
    std::shared_ptr<sycl::ext::oneapi::experimental::detail::exec_graph_impl>
        GraphB) {
  // Bind by const reference so the schedules are not copied; if getSchedule()
  // returns by value the temporary's lifetime is extended by the reference.
  const auto &ScheduleA = GraphA->getSchedule();
  const auto &ScheduleB = GraphB->getSchedule();
  if (ScheduleA.size() != ScheduleB.size())
    return false;

  // Walk both schedules in lock-step. The sizes are equal at this point, so
  // ItB never runs past the end of ScheduleB.
  auto ItB = std::begin(ScheduleB);
  for (const auto &NodeA : ScheduleA) {
    if (!NodeA->isSimilar(*ItB))
      return false;
    ++ItB;
  }
  return true;
}

int main() {
queue Queue;

Expand Down Expand Up @@ -52,14 +76,18 @@ int main() {
auto FinalizeGraph = [&](int ThreadNum) {
SyncPoint.wait();
auto GraphExec = Graph.finalize();
GraphsExecMap.insert(
std::map<int,
exp_ext::command_graph<exp_ext::graph_state::executable>>::
value_type(ThreadNum, GraphExec));
Queue.submit([&](sycl::handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
};

std::vector<std::thread> Threads;
Threads.reserve(NumThreads);

for (unsigned i = 0; i < NumThreads; ++i) {
Threads.emplace_back(FinalizeGraph);
Threads.emplace_back(FinalizeGraph, i);
}

for (unsigned i = 0; i < NumThreads; ++i) {
Expand All @@ -73,6 +101,37 @@ int main() {
Queue.copy(PtrC, DataC.data(), Size);
Queue.wait_and_throw();

// Ref computation
queue QueueRef{Queue.get_context(), Queue.get_device()};
exp_ext::command_graph GraphRef{Queue.get_context(), Queue.get_device()};

T *PtrARef = malloc_device<T>(Size, QueueRef);
T *PtrBRef = malloc_device<T>(Size, QueueRef);
T *PtrCRef = malloc_device<T>(Size, QueueRef);

QueueRef.copy(DataA.data(), PtrARef, Size);
QueueRef.copy(DataB.data(), PtrBRef, Size);
QueueRef.copy(DataC.data(), PtrCRef, Size);
QueueRef.wait_and_throw();

GraphRef.begin_recording(QueueRef);
run_kernels_usm(QueueRef, Size, PtrA, PtrB, PtrC);
GraphRef.end_recording();

for (unsigned i = 0; i < NumThreads; ++i) {
auto GraphExecRef = GraphRef.finalize();
QueueRef.submit(
[&](sycl::handler &CGH) { CGH.ext_oneapi_graph(GraphExecRef); });
auto GraphExecImpl =
sycl::detail::getSyclObjImpl(GraphsExecMap.find(i)->second);
auto GraphExecRefImpl = sycl::detail::getSyclObjImpl(GraphExecRef);
assert(checkExecGraphSchedule(GraphExecImpl, GraphExecRefImpl));
}

free(PtrARef, QueueRef);
free(PtrBRef, QueueRef);
free(PtrCRef, QueueRef);

free(PtrA, Queue);
free(PtrB, Queue);
free(PtrC, Queue);
Expand Down
2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/Threading/update.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// REQUIRES: level_zero, gpu, TEMPORARY_DISABLED
// Disabled as thread safety not yet implemented
// Disabled as Update feature is not yet implemented

// RUN: %clangxx -pthread -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %{run} %t.out
Expand Down
2 changes: 1 addition & 1 deletion sycl/unittests/Extensions/CommandGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ void addMemcpy2D(experimental::detail::modifiable_command_graph &G, queue &Q,
}
ASSERT_EQ(ExceptionCode, sycl::errc::invalid);
}

bool depthSearchSuccessorCheck(
std::shared_ptr<sycl::ext::oneapi::experimental::detail::node_impl> Node) {
if (Node->MSuccessors.size() > 1)
Expand Down

0 comments on commit 1597546

Please sign in to comment.