Skip to content

Commit

Permalink
[SYCL][Graph] Add exceptions on invalid event and queue usage
Browse files Browse the repository at this point in the history
- Throws when waiting on a queue in recording mode
- Throws when waiting on an event from a graph submission
- Throws when calling depends_on with an event outside the graph
- Add tests for these exceptions
  • Loading branch information
Bensuo committed Jul 6, 2023
1 parent d624663 commit 77816b2
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 7 deletions.
6 changes: 6 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1526,6 +1526,12 @@ class __SYCL_EXPORT handler {
setType(detail::CG::CodeplayHostTask);
}

/// @brief Get the command graph if any associated with this handler. It can
/// come from either the associated queue or from being set explicitly through
/// the appropriate constructor.
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
getCommandGraph() const;

public:
handler(const handler &) = delete;
handler(handler &&) = delete;
Expand Down
6 changes: 6 additions & 0 deletions sycl/source/detail/event_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ void event_impl::wait(std::shared_ptr<sycl::detail::event_impl> Self) {
throw sycl::exception(make_error_code(errc::invalid),
"wait method cannot be used for a discarded event.");

if (MGraph.lock()) {
throw sycl::exception(make_error_code(errc::invalid),
"wait method cannot be used for an event associated "
"with a command graph.");
}

#ifdef XPTI_ENABLE_INSTRUMENTATION
void *TelemetryEvent = nullptr;
uint64_t IId;
Expand Down
17 changes: 17 additions & 0 deletions sycl/source/detail/event_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext::oneapi::experimental::detail {
class graph_impl;
}
class context;
namespace detail {
class plugin;
Expand Down Expand Up @@ -265,6 +268,16 @@ class event_impl {
// Get the sync point associated with this event.
sycl::detail::pi::PiExtSyncPoint getSyncPoint() const { return MSyncPoint; }

void setCommandGraph(
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
MGraph = Graph;
}

std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
getCommandGraph() const {
return MGraph.lock();
}

protected:
// When instrumentation is enabled emits trace event for event wait begin and
// returns the telemetry event generated for the wait
Expand Down Expand Up @@ -311,6 +324,10 @@ class event_impl {
std::mutex MMutex;
std::condition_variable cv;

/// Store the command graph associated with this event, if any.
/// This event is also be stored in the graph so a weak_ptr is used.
std::weak_ptr<ext::oneapi::experimental::detail::graph_impl> MGraph;

// If this event represents a submission to a
// sycl::detail::pi::PiExtCommandBuffer the sync point for that submission is
// stored here.
Expand Down
6 changes: 6 additions & 0 deletions sycl/source/detail/queue_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,12 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
TelemetryEvent = instrumentationProlog(CodeLoc, Name, StreamID, IId);
#endif

if (MGraph) {
throw sycl::exception(make_error_code(errc::invalid),
"wait cannot be called for a queue which is "
"recording to a command graph.");
}

std::vector<std::weak_ptr<event_impl>> WeakEvents;
std::vector<event> SharedEvents;
{
Expand Down
32 changes: 25 additions & 7 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ event handler::finalize() {
// Associate an event with this new node and return the event.
GraphImpl->addEventForNode(EventImpl, NodeImpl);

EventImpl->setCommandGraph(GraphImpl);

return detail::createSyclObjFromImpl<event>(EventImpl);
}

Expand Down Expand Up @@ -877,18 +879,25 @@ void handler::depends_on(event Event) {
throw sycl::exception(make_error_code(errc::invalid),
"Queue operation cannot depend on discarded event.");
}
if (auto Graph = getCommandGraph(); Graph) {
auto EventGraph = EventImpl->getCommandGraph();
if (EventGraph == nullptr) {
throw sycl::exception(
make_error_code(errc::invalid),
"Graph nodes cannot depend on events from outside the graph.");
}
if (EventGraph != Graph) {
throw sycl::exception(
make_error_code(errc::invalid),
"Graph nodes cannot depend on events from another graph.");
}
}
CGData.MEvents.push_back(EventImpl);
}

void handler::depends_on(const std::vector<event> &Events) {
for (const event &Event : Events) {
auto EventImpl = detail::getSyclObjImpl(Event);
if (EventImpl->isDiscarded()) {
throw sycl::exception(
make_error_code(errc::invalid),
"Queue operation cannot depend on discarded event.");
}
CGData.MEvents.push_back(EventImpl);
depends_on(Event);
}
}

Expand Down Expand Up @@ -1063,12 +1072,21 @@ void handler::ext_oneapi_graph(
}
// Associate an event with the subgraph node.
auto SubgraphEvent = std::make_shared<event_impl>();
SubgraphEvent->setCommandGraph(ParentGraph);
ParentGraph->addEventForNode(SubgraphEvent, MSubgraphNode);
} else {
// Set the exec graph for execution during finalize.
MExecGraph = GraphImpl;
}
}

std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
handler::getCommandGraph() const {
if (MGraph) {
return MGraph;
}
return MQueue->getCommandGraph();
}

} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
68 changes: 68 additions & 0 deletions sycl/test-e2e/Graph/invalid_depends_on.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// REQUIRES: level_zero, gpu
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// Tests that calling handler::depends_on() for events not part of the graph
// throws.

#include "graph_common.hpp"

int main() {
queue Queue;

ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
Queue.get_device()};
ext::oneapi::experimental::command_graph Graph2{Queue.get_context(),
Queue.get_device()};

auto NormalEvent = Queue.submit(
[&](handler &CGH) { CGH.single_task<class TestKernel1>([=]() {}); });

Graph2.begin_recording(Queue);

auto OtherGraphEvent = Queue.submit(
[&](handler &CGH) { CGH.single_task<class TestKernel2>([=]() {}); });

Graph2.end_recording(Queue);

Graph.begin_recording(Queue);

// Test that depends_on in explicit and record and replay throws from an event
// outside any graph.
try {
auto GraphEvent = Queue.submit([&](handler &CGH) {
CGH.depends_on(NormalEvent);
CGH.single_task<class TestKernel3>([=]() {});
});
} catch (const sycl::exception &e) {
assert(e.code() == sycl::errc::invalid);
}
try {
Graph.add([&](handler &CGH) {
CGH.depends_on(NormalEvent);
CGH.single_task<class TestKernel4>([=]() {});
});
} catch (const sycl::exception &e) {
assert(e.code() == sycl::errc::invalid);
}

// Test that depends_on throws from an event from another graph.
try {
auto GraphEvent = Queue.submit([&](handler &CGH) {
CGH.depends_on(OtherGraphEvent);
CGH.single_task<class TestKernel5>([=]() {});
});
} catch (const sycl::exception &e) {
assert(e.code() == sycl::errc::invalid);
}
try {
Graph.add([&](handler &CGH) {
CGH.depends_on(OtherGraphEvent);
CGH.single_task<class TestKernel6>([=]() {});
});
} catch (const sycl::exception &e) {
assert(e.code() == sycl::errc::invalid);
}

return 0;
}
29 changes: 29 additions & 0 deletions sycl/test-e2e/Graph/invalid_event_wait.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// REQUIRES: level_zero, gpu
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// Tests that waiting on an event returned from a Record and Replay submission
// throws.

#include "graph_common.hpp"

int main() {
queue Queue;

ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
Queue.get_device()};
Graph.begin_recording(Queue);

auto GraphEvent = Queue.submit(
[&](handler &CGH) { CGH.single_task<class TestKernel>([=]() {}); });

Graph.end_recording(Queue);

try {
GraphEvent.wait();
} catch (const sycl::exception &e) {
assert(e.code() == sycl::errc::invalid);
}

return 0;
}
23 changes: 23 additions & 0 deletions sycl/test-e2e/Graph/invalid_queue_wait.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// REQUIRES: level_zero, gpu
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// Tests that waiting on a Queue in recording mode throws.

#include "graph_common.hpp"

int main() {
queue Queue;

ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
Queue.get_device()};
Graph.begin_recording(Queue);

try {
Queue.wait();
} catch (const sycl::exception &e) {
assert(e.code() == sycl::errc::invalid);
}

return 0;
}

0 comments on commit 77816b2

Please sign in to comment.