Skip to content

Commit

Permalink
[SYCL][Graph] enable_shared_from_this refactor (#15195)
Browse files Browse the repository at this point in the history
Use `std::enable_shared_from_this` to remove need for passing a shared
pointer of `this` as a function parameter.

`std::enable_shared_from_this` usage was previously introduced to graph
code in #14453 (comment)
  • Loading branch information
EwanC authored Aug 30, 2024
1 parent 57cf62c commit f04c79b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 52 deletions.
39 changes: 17 additions & 22 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ graph_impl::~graph_impl() {
}

std::shared_ptr<node_impl> graph_impl::addNodesToExits(
const std::shared_ptr<graph_impl> &Impl,
const std::list<std::shared_ptr<node_impl>> &NodeList) {
// Find all input and output nodes from the node list
std::vector<std::shared_ptr<node_impl>> Inputs;
Expand All @@ -327,18 +326,18 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
for (auto &NodeImpl : MNodeStorage) {
if (NodeImpl->MSuccessors.size() == 0) {
for (auto &Input : Inputs) {
NodeImpl->registerSuccessor(Input, NodeImpl);
NodeImpl->registerSuccessor(Input);
}
}
}

// Add all the new nodes to the node storage
for (auto &Node : NodeList) {
MNodeStorage.push_back(Node);
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), Node);
addEventForNode(std::make_shared<sycl::detail::event_impl>(), Node);
}

return this->add(Impl, Outputs);
return this->add(Outputs);
}

void graph_impl::addRoot(const std::shared_ptr<node_impl> &Root) {
Expand All @@ -350,8 +349,7 @@ void graph_impl::removeRoot(const std::shared_ptr<node_impl> &Root) {
}

std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<std::shared_ptr<node_impl>> &Dep) {
graph_impl::add(const std::vector<std::shared_ptr<node_impl>> &Dep) {
// Copy deps so we can modify them
auto Deps = Dep;

Expand All @@ -361,17 +359,16 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,

addDepsToNode(NodeImpl, Deps);
// Add an event associated with this explicit node for mixed usage
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), NodeImpl);
addEventForNode(std::make_shared<sycl::detail::event_impl>(), NodeImpl);
return NodeImpl;
}

std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
std::function<void(handler &)> CGF,
graph_impl::add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
const std::vector<std::shared_ptr<node_impl>> &Dep) {
(void)Args;
sycl::handler Handler{Impl};
sycl::handler Handler{shared_from_this()};
CGF(Handler);

if (Handler.getType() == sycl::detail::CGType::Barrier) {
Expand All @@ -394,7 +391,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Dep);
NodeImpl->MNDRangeUsed = Handler.impl->MNDRangeUsed;
// Add an event associated with this explicit node for mixed usage
addEventForNode(Impl, std::make_shared<sycl::detail::event_impl>(), NodeImpl);
addEventForNode(std::make_shared<sycl::detail::event_impl>(), NodeImpl);

// Retrieve any dynamic parameters which have been registered in the CGF and
// register the actual nodes with them.
Expand All @@ -414,8 +411,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
}

std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<sycl::detail::EventImplPtr> Events) {
graph_impl::add(const std::vector<sycl::detail::EventImplPtr> Events) {

std::vector<std::shared_ptr<node_impl>> Deps;

Expand All @@ -430,7 +426,7 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
}
}

return this->add(Impl, Deps);
return this->add(Deps);
}

std::shared_ptr<node_impl>
Expand Down Expand Up @@ -594,7 +590,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
}

// We need to add the edges first before checking for cycles
Src->registerSuccessor(Dest, Src);
Src->registerSuccessor(Dest);

// We can skip cycle checks if either Dest has no successors (cycle not
// possible) or cycle checks have been disabled with the no_cycle_check
Expand Down Expand Up @@ -1061,7 +1057,7 @@ void exec_graph_impl::duplicateNodes() {
// register those as successors with the current copied node
for (auto &NextNode : OriginalNode->MSuccessors) {
auto Successor = NodesMap.at(NextNode.lock());
NodeCopy->registerSuccessor(Successor, NodeCopy);
NodeCopy->registerSuccessor(Successor);
}
}

Expand Down Expand Up @@ -1103,7 +1099,7 @@ void exec_graph_impl::duplicateNodes() {

for (auto &NextNode : SubgraphNode->MSuccessors) {
auto Successor = SubgraphNodesMap.at(NextNode.lock());
NodeCopy->registerSuccessor(Successor, NodeCopy);
NodeCopy->registerSuccessor(Successor);
}
}

Expand Down Expand Up @@ -1137,7 +1133,7 @@ void exec_graph_impl::duplicateNodes() {
// Add all input nodes from the subgraph as successors for this node
// instead
for (auto &Input : Inputs) {
PredNode->registerSuccessor(Input, PredNode);
PredNode->registerSuccessor(Input);
}
}

Expand All @@ -1157,7 +1153,7 @@ void exec_graph_impl::duplicateNodes() {
// Add all Output nodes from the subgraph as predecessors for this node
// instead
for (auto &Output : Outputs) {
Output->registerSuccessor(SuccNode, Output);
Output->registerSuccessor(SuccNode);
}
}

Expand Down Expand Up @@ -1531,7 +1527,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(impl, DepImpls);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

Expand All @@ -1544,8 +1540,7 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl =
impl->add(impl, CGF, {}, DepImpls);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

Expand Down
43 changes: 14 additions & 29 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
}

/// Implementation of node class from SYCL_EXT_ONEAPI_GRAPH.
class node_impl {
class node_impl : public std::enable_shared_from_this<node_impl> {
public:
using id_type = uint64_t;

Expand Down Expand Up @@ -112,20 +112,15 @@ class node_impl {

/// Add successor to the node.
/// @param Node Node to add as a successor.
/// @param Prev Predecessor to \p node being added as successor.
///
/// \p Prev should be a shared_ptr to an instance of this object, but can't
/// use a raw \p this pointer, so the extra \p Prev parameter is passed.
void registerSuccessor(const std::shared_ptr<node_impl> &Node,
const std::shared_ptr<node_impl> &Prev) {
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
[Node](const std::weak_ptr<node_impl> &Ptr) {
return Ptr.lock() == Node;
}) != MSuccessors.end()) {
return;
}
MSuccessors.push_back(Node);
Node->registerPredecessor(Prev);
Node->registerPredecessor(shared_from_this());
}

/// Add predecessor to the node.
Expand Down Expand Up @@ -161,9 +156,10 @@ class node_impl {
/// Construct a node from another node. This will perform a deep-copy of the
/// command group object associated with this node.
node_impl(node_impl &Other)
: MSuccessors(Other.MSuccessors), MPredecessors(Other.MPredecessors),
MCGType(Other.MCGType), MNodeType(Other.MNodeType),
MCommandGroup(Other.getCGCopy()), MSubGraphImpl(Other.MSubGraphImpl) {}
: enable_shared_from_this(Other), MSuccessors(Other.MSuccessors),
MPredecessors(Other.MPredecessors), MCGType(Other.MCGType),
MNodeType(Other.MNodeType), MCommandGroup(Other.getCGCopy()),
MSubGraphImpl(Other.MSubGraphImpl) {}

/// Copy-assignment operator. This will perform a deep-copy of the
/// command group object associated with this node.
Expand Down Expand Up @@ -901,32 +897,26 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
const std::vector<std::shared_ptr<node_impl>> &Dep = {});

/// Create a CGF node in the graph.
/// @param Impl Graph implementation pointer to create a handler with.
/// @param CGF Command-group function to create node with.
/// @param Args Node arguments.
/// @param Dep Dependencies of the created node.
/// @return Created node in the graph.
std::shared_ptr<node_impl>
add(const std::shared_ptr<graph_impl> &Impl,
std::function<void(handler &)> CGF,
add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
const std::vector<std::shared_ptr<node_impl>> &Dep = {});

/// Create an empty node in the graph.
/// @param Impl Graph implementation pointer.
/// @param Dep List of predecessor nodes.
/// @return Created node in the graph.
std::shared_ptr<node_impl>
add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<std::shared_ptr<node_impl>> &Dep = {});
add(const std::vector<std::shared_ptr<node_impl>> &Dep = {});

/// Create an empty node in the graph.
/// @param Impl Graph implementation pointer.
/// @param Events List of events associated to this node.
/// @return Created node in the graph.
std::shared_ptr<node_impl>
add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<sycl::detail::EventImplPtr> Events);
add(const std::vector<sycl::detail::EventImplPtr> Events);

/// Add a queue to the set of queues which are currently recording to this
/// graph.
Expand All @@ -951,15 +941,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
bool clearQueues();

/// Associate a sycl event with a node in the graph.
/// @param GraphImpl shared_ptr to Graph impl associated with this event, aka
/// this.
/// @param EventImpl Event to associate with a node in map.
/// @param NodeImpl Node to associate with event in map.
void addEventForNode(std::shared_ptr<graph_impl> GraphImpl,
std::shared_ptr<sycl::detail::event_impl> EventImpl,
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
std::shared_ptr<node_impl> NodeImpl) {
if (!(EventImpl->getCommandGraph()))
EventImpl->setCommandGraph(GraphImpl);
EventImpl->setCommandGraph(shared_from_this());
MEventsMap[EventImpl] = NodeImpl;
}

Expand Down Expand Up @@ -1238,12 +1225,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
void addRoot(const std::shared_ptr<node_impl> &Root);

/// Adds nodes to the exit nodes of this graph.
/// @param Impl Graph implementation pointer.
/// @param NodeList List of nodes from sub-graph in schedule order.
/// @return An empty node is used to schedule dependencies on this sub-graph.
std::shared_ptr<node_impl>
addNodesToExits(const std::shared_ptr<graph_impl> &Impl,
const std::list<std::shared_ptr<node_impl>> &NodeList);
addNodesToExits(const std::list<std::shared_ptr<node_impl>> &NodeList);

/// Adds dependencies for a new node, if it has no deps it will be
/// added as a root node.
Expand All @@ -1253,7 +1238,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
const std::vector<std::shared_ptr<node_impl>> &Deps) {
if (!Deps.empty()) {
for (auto &N : Deps) {
N->registerSuccessor(Node, N);
N->registerSuccessor(Node);
this->removeRoot(Node);
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ event handler::finalize() {
}

// Associate an event with this new node and return the event.
GraphImpl->addEventForNode(GraphImpl, EventImpl, NodeImpl);
GraphImpl->addEventForNode(EventImpl, NodeImpl);

NodeImpl->MNDRangeUsed = impl->MNDRangeUsed;

Expand Down

0 comments on commit f04c79b

Please sign in to comment.