Skip to content

Commit

Permalink
[SYCL][Graph] Add exceptions on invalid event and queue usage (#250)
Browse files Browse the repository at this point in the history
- Throws when waiting on a queue in recording mode
- Throws when waiting on an event from a graph submission
- Throws when calling depends_on with an event outside the graph
- Add tests for these exceptions

---------

Co-authored-by: Ewan Crawford <ewan@codeplay.com>
  • Loading branch information
Bensuo and EwanC authored Jul 11, 2023
1 parent c7389c9 commit 7d88887
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 7 deletions.
6 changes: 6 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1526,6 +1526,12 @@ class __SYCL_EXPORT handler {
setType(detail::CG::CodeplayHostTask);
}

/// @brief Get the command graph if any associated with this handler. It can
/// come from either the associated queue or from being set explicitly through
/// the appropriate constructor.
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
getCommandGraph() const;

public:
handler(const handler &) = delete;
handler(handler &&) = delete;
Expand Down
6 changes: 6 additions & 0 deletions sycl/source/detail/event_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ void event_impl::wait(std::shared_ptr<sycl::detail::event_impl> Self) {
throw sycl::exception(make_error_code(errc::invalid),
"wait method cannot be used for a discarded event.");

if (MGraph.lock()) {
throw sycl::exception(make_error_code(errc::invalid),
"wait method cannot be used for an event associated "
"with a command graph.");
}

#ifdef XPTI_ENABLE_INSTRUMENTATION
void *TelemetryEvent = nullptr;
uint64_t IId;
Expand Down
17 changes: 17 additions & 0 deletions sycl/source/detail/event_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext::oneapi::experimental::detail {
class graph_impl;
}
class context;
namespace detail {
class plugin;
Expand Down Expand Up @@ -265,6 +268,16 @@ class event_impl {
// Get the sync point associated with this event.
sycl::detail::pi::PiExtSyncPoint getSyncPoint() const { return MSyncPoint; }

void setCommandGraph(
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
MGraph = Graph;
}

std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
getCommandGraph() const {
return MGraph.lock();
}

protected:
// When instrumentation is enabled emits trace event for event wait begin and
// returns the telemetry event generated for the wait
Expand Down Expand Up @@ -311,6 +324,10 @@ class event_impl {
std::mutex MMutex;
std::condition_variable cv;

/// Store the command graph associated with this event, if any.
/// This event is also be stored in the graph so a weak_ptr is used.
std::weak_ptr<ext::oneapi::experimental::detail::graph_impl> MGraph;

// If this event represents a submission to a
// sycl::detail::pi::PiExtCommandBuffer the sync point for that submission is
// stored here.
Expand Down
6 changes: 6 additions & 0 deletions sycl/source/detail/queue_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,12 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
TelemetryEvent = instrumentationProlog(CodeLoc, Name, StreamID, IId);
#endif

if (MGraph) {
throw sycl::exception(make_error_code(errc::invalid),
"wait cannot be called for a queue which is "
"recording to a command graph.");
}

std::vector<std::weak_ptr<event_impl>> WeakEvents;
std::vector<event> SharedEvents;
{
Expand Down
32 changes: 25 additions & 7 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ event handler::finalize() {
// Associate an event with this new node and return the event.
GraphImpl->addEventForNode(EventImpl, NodeImpl);

EventImpl->setCommandGraph(GraphImpl);

return detail::createSyclObjFromImpl<event>(EventImpl);
}

Expand Down Expand Up @@ -877,18 +879,25 @@ void handler::depends_on(event Event) {
throw sycl::exception(make_error_code(errc::invalid),
"Queue operation cannot depend on discarded event.");
}
if (auto Graph = getCommandGraph(); Graph) {
auto EventGraph = EventImpl->getCommandGraph();
if (EventGraph == nullptr) {
throw sycl::exception(
make_error_code(errc::invalid),
"Graph nodes cannot depend on events from outside the graph.");
}
if (EventGraph != Graph) {
throw sycl::exception(
make_error_code(errc::invalid),
"Graph nodes cannot depend on events from another graph.");
}
}
CGData.MEvents.push_back(EventImpl);
}

void handler::depends_on(const std::vector<event> &Events) {
for (const event &Event : Events) {
auto EventImpl = detail::getSyclObjImpl(Event);
if (EventImpl->isDiscarded()) {
throw sycl::exception(
make_error_code(errc::invalid),
"Queue operation cannot depend on discarded event.");
}
CGData.MEvents.push_back(EventImpl);
depends_on(Event);
}
}

Expand Down Expand Up @@ -1063,12 +1072,21 @@ void handler::ext_oneapi_graph(
}
// Associate an event with the subgraph node.
auto SubgraphEvent = std::make_shared<event_impl>();
SubgraphEvent->setCommandGraph(ParentGraph);
ParentGraph->addEventForNode(SubgraphEvent, MSubgraphNode);
} else {
// Set the exec graph for execution during finalize.
MExecGraph = GraphImpl;
}
}

std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
handler::getCommandGraph() const {
if (MGraph) {
return MGraph;
}
return MQueue->getCommandGraph();
}

} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
79 changes: 79 additions & 0 deletions sycl/test-e2e/Graph/invalid_depends_on.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// REQUIRES: level_zero, gpu
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// Tests that calling handler::depends_on() for events not part of the graph
// throws.

#include "graph_common.hpp"

int main() {
queue Queue;

ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
Queue.get_device()};
ext::oneapi::experimental::command_graph Graph2{Queue.get_context(),
Queue.get_device()};

auto NormalEvent = Queue.submit(
[&](handler &CGH) { CGH.single_task<class TestKernel1>([=]() {}); });

Graph2.begin_recording(Queue);

auto OtherGraphEvent = Queue.submit(
[&](handler &CGH) { CGH.single_task<class TestKernel2>([=]() {}); });

Graph2.end_recording(Queue);

Graph.begin_recording(Queue);

// Test that depends_on in explicit and record and replay throws from an event
// outside any graph.

std::error_code ErrorCode = make_error_code(sycl::errc::success);
try {
auto GraphEvent = Queue.submit([&](handler &CGH) {
CGH.depends_on(NormalEvent);
CGH.single_task<class TestKernel3>([=]() {});
});
} catch (const sycl::exception &e) {
ErrorCode = e.code();
}
assert(ErrorCode == sycl::errc::invalid);

ErrorCode = make_error_code(sycl::errc::success);
try {
Graph.add([&](handler &CGH) {
CGH.depends_on(NormalEvent);
CGH.single_task<class TestKernel4>([=]() {});
});
} catch (const sycl::exception &e) {
ErrorCode = e.code();
}
assert(ErrorCode == sycl::errc::invalid);

// Test that depends_on throws from an event from another graph.
ErrorCode = make_error_code(sycl::errc::success);
try {
auto GraphEvent = Queue.submit([&](handler &CGH) {
CGH.depends_on(OtherGraphEvent);
CGH.single_task<class TestKernel5>([=]() {});
});
} catch (const sycl::exception &e) {
ErrorCode = e.code();
}
assert(ErrorCode == sycl::errc::invalid);

ErrorCode = make_error_code(sycl::errc::success);
try {
Graph.add([&](handler &CGH) {
CGH.depends_on(OtherGraphEvent);
CGH.single_task<class TestKernel6>([=]() {});
});
} catch (const sycl::exception &e) {
ErrorCode = e.code();
}
assert(ErrorCode == sycl::errc::invalid);

return 0;
}
31 changes: 31 additions & 0 deletions sycl/test-e2e/Graph/invalid_event_wait.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// REQUIRES: level_zero, gpu
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// Tests that waiting on an event returned from a Record and Replay submission
// throws.

#include "graph_common.hpp"

int main() {
queue Queue;

ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
Queue.get_device()};
Graph.begin_recording(Queue);

auto GraphEvent = Queue.submit(
[&](handler &CGH) { CGH.single_task<class TestKernel>([=]() {}); });

Graph.end_recording(Queue);

std::error_code ErrorCode = make_error_code(sycl::errc::success);
try {
GraphEvent.wait();
} catch (const sycl::exception &e) {
ErrorCode = e.code();
}
assert(ErrorCode == sycl::errc::invalid);

return 0;
}
26 changes: 26 additions & 0 deletions sycl/test-e2e/Graph/invalid_queue_wait.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// REQUIRES: level_zero, gpu
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// Tests that waiting on a Queue in recording mode throws.

#include "graph_common.hpp"

int main() {
queue Queue;

ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
Queue.get_device()};
Graph.begin_recording(Queue);

std::error_code ErrorCode = make_error_code(sycl::errc::success);

try {
Queue.wait();
} catch (const sycl::exception &e) {
ErrorCode = e.code();
}
assert(ErrorCode == sycl::errc::invalid);

return 0;
}
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -4575,6 +4575,7 @@ _ZNK4sycl3_V17context8get_infoINS0_4info7context32atomic_memory_scope_capabiliti
_ZNK4sycl3_V17context8get_infoINS0_4info7context7devicesEEENS0_6detail20is_context_info_descIT_E11return_typeEv
_ZNK4sycl3_V17context8get_infoINS0_4info7context8platformEEENS0_6detail20is_context_info_descIT_E11return_typeEv
_ZNK4sycl3_V17context9getNativeEv
_ZNK4sycl3_V17handler15getCommandGraphEv
_ZNK4sycl3_V17handler17getContextImplPtrEv
_ZNK4sycl3_V17handler27isStateExplicitKernelBundleEv
_ZNK4sycl3_V17handler30getOrInsertHandlerKernelBundleEb
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,7 @@
?getChannelType@UnsampledImageAccessorBaseHost@detail@_V1@sycl@@QEBA?AW4image_channel_type@34@XZ
?getChannelType@image_impl@detail@_V1@sycl@@QEBA?AW4image_channel_type@34@XZ
?getChannelType@image_plain@detail@_V1@sycl@@IEBA?AW4image_channel_type@34@XZ
?getCommandGraph@handler@_V1@sycl@@AEBA?AV?$shared_ptr@Vgraph_impl@detail@experimental@oneapi@ext@_V1@sycl@@@std@@XZ
?getContextImplPtr@handler@_V1@sycl@@AEBAAEBV?$shared_ptr@Vcontext_impl@detail@_V1@sycl@@@std@@XZ
?getCurrentDSODir@OSUtil@detail@_V1@sycl@@SA?AV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@XZ
?getDeviceFromHandler@detail@_V1@sycl@@YA?AVdevice@23@AEAVhandler@23@@Z
Expand Down

0 comments on commit 7d88887

Please sign in to comment.