Skip to content

Commit

Permalink
[SYCL][Graph] Makes command graph functions thread-safe (#265)
Browse files Browse the repository at this point in the history
* [SYCL][Graph] Makes command graph functions thread-safe

Addresses comments made on the first PR commit.
Mutexes are now added to Graph implementation entry points
instead of end points as was the case in the previous commit.
Adds "build_pthread_inc" lit test macro to facilitate the
compilation of the threading tests.
Removes std::barrier (std-20) dependency in threading tests.

Addresses Issue: #85

* [SYCL][Graph] Makes command graph functions thread-safe

Moves threading tests that do not require a device to run to unit tests

* Update sycl/source/detail/graph_impl.cpp

Co-authored-by: Ben Tracy <ben.tracy@codeplay.com>

* [SYCL][Graph] Makes command graph functions thread-safe

Adds some comments.

* Update sycl/source/handler.cpp

Co-authored-by: Pablo Reble <pablo.reble@intel.com>

* Update sycl/source/detail/graph_impl.hpp

Co-authored-by: Ewan Crawford <ewan@codeplay.com>

* [SYCL][Graph] Makes command graph functions thread-safe

Adds dedicated sub-class to unit tests for multi-threading unit tests

* [SYCL][Graph] Makes command graph functions thread-safe

Adds comments

* [SYCL][Graph] thread-safe: bug fix after rebase

---------

Co-authored-by: Ben Tracy <ben.tracy@codeplay.com>
Co-authored-by: Pablo Reble <pablo.reble@intel.com>
Co-authored-by: Ewan Crawford <ewan@codeplay.com>
  • Loading branch information
4 people authored Aug 4, 2023
1 parent b99238b commit 0d6bf2a
Show file tree
Hide file tree
Showing 13 changed files with 417 additions and 210 deletions.
18 changes: 16 additions & 2 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
checkForRequirement(Req, NodePtr, UniqueDeps);
}
}

// Add any nodes specified by event dependencies into the dependency list
for (auto Dep : CommandGroup->getEvents()) {
if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl != MEventsMap.end()) {
Expand Down Expand Up @@ -474,6 +473,8 @@ void exec_graph_impl::createCommandBuffers(sycl::device Device) {
}

exec_graph_impl::~exec_graph_impl() {
WriteLock Lock(MMutex);

// clear all recording queue if not done before (no call to end_recording)
MGraphImpl->clearQueues();

Expand All @@ -499,6 +500,8 @@ exec_graph_impl::~exec_graph_impl() {
sycl::event
exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
sycl::detail::CG::StorageInitHelper CGData) {
WriteLock Lock(MMutex);

auto CreateNewEvent([&]() {
auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
NewEvent->setContextImpl(Queue->getContextImplPtr());
Expand Down Expand Up @@ -612,6 +615,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}
Expand All @@ -624,6 +628,7 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl =
impl->add(impl, CGF, {}, DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
Expand All @@ -635,6 +640,7 @@ void modifiable_command_graph::make_edge(node &Src, node &Dest) {
std::shared_ptr<detail::node_impl> ReceiverImpl =
sycl::detail::getSyclObjImpl(Dest);

graph_impl::WriteLock Lock(impl->MMutex);
impl->makeEdge(SenderImpl, ReceiverImpl);
}

Expand Down Expand Up @@ -666,6 +672,7 @@ bool modifiable_command_graph::begin_recording(queue &RecordingQueue) {

if (QueueImpl->getCommandGraph() == nullptr) {
QueueImpl->setCommandGraph(impl);
graph_impl::WriteLock Lock(impl->MMutex);
impl->addQueue(QueueImpl);
return true;
}
Expand All @@ -687,12 +694,16 @@ bool modifiable_command_graph::begin_recording(
return QueueStateChanged;
}

bool modifiable_command_graph::end_recording() { return impl->clearQueues(); }
/// Calls clearQueues() on the graph implementation while holding the
/// implementation's write lock.
/// @return The value returned by clearQueues().
bool modifiable_command_graph::end_recording() {
  // Exclusive lock is held until the function returns.
  graph_impl::WriteLock Guard(impl->MMutex);
  const bool Result = impl->clearQueues();
  return Result;
}

bool modifiable_command_graph::end_recording(queue &RecordingQueue) {
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
if (QueueImpl->getCommandGraph() == impl) {
QueueImpl->setCommandGraph(nullptr);
graph_impl::WriteLock Lock(impl->MMutex);
impl->removeQueue(QueueImpl);
return true;
}
Expand All @@ -719,6 +730,9 @@ executable_command_graph::executable_command_graph(
const std::shared_ptr<detail::graph_impl> &Graph, const sycl::context &Ctx)
: MTag(rand()),
impl(std::make_shared<detail::exec_graph_impl>(Ctx, Graph)) {
// Graph is read and written in this scope so we lock
// this graph with full privileges.
graph_impl::WriteLock Lock(Graph->MMutex);
finalizeImpl(); // Create backend representation for executable graph
}

Expand Down
47 changes: 42 additions & 5 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <functional>
#include <list>
#include <set>
#include <shared_mutex>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -326,6 +327,16 @@ class node_impl {
return true;
}

/// Recusively computes the number of successor nodes
/// @return number of successor nodes
size_t depthSearchCount() const {
size_t NumberOfNodes = 1;
for (const auto &Succ : MSuccessors) {
NumberOfNodes += Succ->depthSearchCount();
}
return NumberOfNodes;
}

private:
/// Creates a copy of the node's CG by casting to its actual type, then using
/// that to copy construct and create a new unique ptr from that copy.
Expand All @@ -339,6 +350,12 @@ class node_impl {
/// Implementation details of command_graph<modifiable>.
class graph_impl {
public:
using ReadLock = std::shared_lock<std::shared_mutex>;
using WriteLock = std::unique_lock<std::shared_mutex>;

/// Protects all the fields that can be changed by class' methods.
mutable std::shared_mutex MMutex;

/// Constructor.
/// @param SyclContext Context to use for graph.
/// @param SyclDevice Device to create nodes with.
Expand All @@ -352,10 +369,6 @@ class graph_impl {
}
}

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void addRoot(const std::shared_ptr<node_impl> &Root);

/// Remove node from list of root nodes.
/// @param Root Node to remove from list of root nodes.
void removeRoot(const std::shared_ptr<node_impl> &Root);
Expand Down Expand Up @@ -429,13 +442,13 @@ class graph_impl {
/// @return Event associated with node.
std::shared_ptr<sycl::detail::event_impl>
getEventForNode(std::shared_ptr<node_impl> NodeImpl) const {
ReadLock Lock(MMutex);
if (auto EventImpl = std::find_if(
MEventsMap.begin(), MEventsMap.end(),
[NodeImpl](auto &it) { return it.second == NodeImpl; });
EventImpl != MEventsMap.end()) {
return EventImpl->first;
}

throw sycl::exception(
sycl::make_error_code(errc::invalid),
"No event has been recorded for the specified graph node");
Expand Down Expand Up @@ -594,6 +607,16 @@ class graph_impl {
}
}

/// Computes the number of nodes in the graph by summing, over every root
/// node, the result of that root's depthSearchCount().
/// This method does not acquire MMutex itself.
/// @return Sum of the per-root node counts.
size_t getNumberOfNodes() const {
  size_t Total = 0;
  for (auto RootIt = MRoots.begin(); RootIt != MRoots.end(); ++RootIt) {
    Total += (*RootIt)->depthSearchCount();
  }
  return Total;
}

private:
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
/// @param NodeFunc A function which receives as input a node in the graph to
Expand Down Expand Up @@ -632,11 +655,21 @@ class graph_impl {
/// Controls whether we skip the cycle checks in makeEdge, set by the presence
/// of the no_cycle_check property on construction.
bool MSkipCycleChecks = false;

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void addRoot(const std::shared_ptr<node_impl> &Root);
};

/// Class representing the implementation of command_graph<executable>.
class exec_graph_impl {
public:
using ReadLock = std::shared_lock<std::shared_mutex>;
using WriteLock = std::unique_lock<std::shared_mutex>;

/// Protects all the fields that can be changed by class' methods.
mutable std::shared_mutex MMutex;

/// Constructor.
/// @param Context Context to create graph with.
/// @param GraphImpl Modifiable graph implementation to create with.
Expand Down Expand Up @@ -739,6 +772,10 @@ class exec_graph_impl {
std::list<std::shared_ptr<node_impl>> MSchedule;
/// Pointer to the modifiable graph impl associated with this executable
/// graph.
/// Thread-safe implementation note: in the current implementation
/// multiple exec_graph_impl can reference the same graph_impl object.
/// This specificity must be taken into account when trying to lock
/// the graph_impl mutex from an exec_graph_impl to avoid deadlock.
std::shared_ptr<graph_impl> MGraphImpl;
/// Map of devices to command buffers.
std::unordered_map<sycl::device, sycl::detail::pi::PiExtCommandBuffer>
Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ class queue_impl {

/// Stores \p Graph as the command graph associated with this queue.
/// The assignment is performed while holding the queue's MMutex.
/// @param Graph Graph implementation to associate with the queue; callers in
/// this file also pass nullptr to clear the association.
void setCommandGraph(
    std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
  const std::lock_guard<std::mutex> Guard(MMutex);
  MGraph = Graph;
}

Expand Down
18 changes: 18 additions & 0 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,11 @@ event handler::finalize() {
std::shared_ptr<ext::oneapi::experimental::detail::node_impl> NodeImpl =
nullptr;

// GraphImpl is read and written in this scope so we lock this graph
// with full privileges.
ext::oneapi::experimental::detail::graph_impl::WriteLock Lock(
GraphImpl->MMutex);

// Create a new node in the graph representing this command-group
if (MQueue->isInOrder()) {
// In-order queues create implicit linear dependencies between nodes.
Expand Down Expand Up @@ -1047,15 +1052,28 @@ void handler::ext_oneapi_graph(
Graph) {
MCGType = detail::CG::ExecCommandBuffer;
auto GraphImpl = detail::getSyclObjImpl(Graph);
// GraphImpl is only read in this scope so we lock this graph for read only
ext::oneapi::experimental::detail::graph_impl::ReadLock Lock(
GraphImpl->MMutex);

std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> ParentGraph;
if (MQueue) {
ParentGraph = MQueue->getCommandGraph();
} else {
ParentGraph = MGraph;
}

ext::oneapi::experimental::detail::graph_impl::WriteLock ParentLock;
// If a parent graph is set that means we are adding or recording a subgraph
if (ParentGraph) {
// ParentGraph is read and written in this scope so we lock this graph
// with full privileges.
// We only lock for the Record&Replay API because the graph has already been
// locked if this function was called from the explicit API function add
if (MQueue) {
ParentLock = ext::oneapi::experimental::detail::graph_impl::WriteLock(
ParentGraph->MMutex);
}
// Store the node representing the subgraph in the handler so that we can
// return it to the user later.
MSubgraphNode = ParentGraph->addSubgraphNodes(GraphImpl->getSchedule());
Expand Down
67 changes: 0 additions & 67 deletions sycl/test-e2e/Graph/Threading/begin_end_recording.cpp

This file was deleted.

60 changes: 0 additions & 60 deletions sycl/test-e2e/Graph/Threading/explicit_add_nodes.cpp

This file was deleted.

Loading

0 comments on commit 0d6bf2a

Please sign in to comment.