diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 91f3a9aa336fd..2cff525b2dd63 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -1526,6 +1526,12 @@ class __SYCL_EXPORT handler { setType(detail::CG::CodeplayHostTask); } + /// @brief Get the command graph if any associated with this handler. It can + /// come from either the associated queue or from being set explicitly through + /// the appropriate constructor. + std::shared_ptr + getCommandGraph() const; + public: handler(const handler &) = delete; handler(handler &&) = delete; diff --git a/sycl/source/detail/event_impl.cpp b/sycl/source/detail/event_impl.cpp index 75807863111df..7ff666cdbb419 100644 --- a/sycl/source/detail/event_impl.cpp +++ b/sycl/source/detail/event_impl.cpp @@ -223,6 +223,12 @@ void event_impl::wait(std::shared_ptr Self) { throw sycl::exception(make_error_code(errc::invalid), "wait method cannot be used for a discarded event."); + if (MGraph.lock()) { + throw sycl::exception(make_error_code(errc::invalid), + "wait method cannot be used for an event associated " + "with a command graph."); + } + #ifdef XPTI_ENABLE_INSTRUMENTATION void *TelemetryEvent = nullptr; uint64_t IId; diff --git a/sycl/source/detail/event_impl.hpp b/sycl/source/detail/event_impl.hpp index 47d52d25e643e..073cb9e23d53f 100644 --- a/sycl/source/detail/event_impl.hpp +++ b/sycl/source/detail/event_impl.hpp @@ -23,6 +23,9 @@ namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext::oneapi::experimental::detail { +class graph_impl; +} class context; namespace detail { class plugin; @@ -265,6 +268,16 @@ class event_impl { // Get the sync point associated with this event. sycl::detail::pi::PiExtSyncPoint getSyncPoint() const { return MSyncPoint; } + void setCommandGraph( + std::shared_ptr Graph) { + MGraph = Graph; + } + + std::shared_ptr + getCommandGraph() const { + return MGraph.lock(); + } + protected: // When instrumentation is enabled emits trace event for event wait begin and // returns the telemetry event generated for the wait @@ -311,6 +324,10 @@ class event_impl { std::mutex MMutex; std::condition_variable cv; + /// Store the command graph associated with this event, if any. + /// This event is also be stored in the graph so a weak_ptr is used. + std::weak_ptr MGraph; + // If this event represents a submission to a // sycl::detail::pi::PiExtCommandBuffer the sync point for that submission is // stored here. diff --git a/sycl/source/detail/queue_impl.cpp b/sycl/source/detail/queue_impl.cpp index 8f9d3f8284fdd..adf4a660d2c9f 100644 --- a/sycl/source/detail/queue_impl.cpp +++ b/sycl/source/detail/queue_impl.cpp @@ -478,6 +478,12 @@ void queue_impl::wait(const detail::code_location &CodeLoc) { TelemetryEvent = instrumentationProlog(CodeLoc, Name, StreamID, IId); #endif + if (MGraph) { + throw sycl::exception(make_error_code(errc::invalid), + "wait cannot be called for a queue which is " + "recording to a command graph."); + } + std::vector> WeakEvents; std::vector SharedEvents; { diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 5dbd1f2dd466d..359456507729b 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -438,6 +438,8 @@ event handler::finalize() { // Associate an event with this new node and return the event. GraphImpl->addEventForNode(EventImpl, NodeImpl); + EventImpl->setCommandGraph(GraphImpl); + return detail::createSyclObjFromImpl(EventImpl); } @@ -877,18 +879,25 @@ void handler::depends_on(event Event) { throw sycl::exception(make_error_code(errc::invalid), "Queue operation cannot depend on discarded event."); } + if (auto Graph = getCommandGraph(); Graph) { + auto EventGraph = EventImpl->getCommandGraph(); + if (EventGraph == nullptr) { + throw sycl::exception( + make_error_code(errc::invalid), + "Graph nodes cannot depend on events from outside the graph."); + } + if (EventGraph != Graph) { + throw sycl::exception( + make_error_code(errc::invalid), + "Graph nodes cannot depend on events from another graph."); + } + } CGData.MEvents.push_back(EventImpl); } void handler::depends_on(const std::vector &Events) { for (const event &Event : Events) { - auto EventImpl = detail::getSyclObjImpl(Event); - if (EventImpl->isDiscarded()) { - throw sycl::exception( - make_error_code(errc::invalid), - "Queue operation cannot depend on discarded event."); - } - CGData.MEvents.push_back(EventImpl); + depends_on(Event); } } @@ -1063,6 +1072,7 @@ void handler::ext_oneapi_graph( } // Associate an event with the subgraph node. auto SubgraphEvent = std::make_shared(); + SubgraphEvent->setCommandGraph(ParentGraph); ParentGraph->addEventForNode(SubgraphEvent, MSubgraphNode); } else { // Set the exec graph for execution during finalize. @@ -1070,5 +1080,13 @@ void handler::ext_oneapi_graph( } } +std::shared_ptr +handler::getCommandGraph() const { + if (MGraph) { + return MGraph; + } + return MQueue->getCommandGraph(); +} + } // __SYCL_INLINE_VER_NAMESPACE(_V1) } // namespace sycl diff --git a/sycl/test-e2e/Graph/invalid_depends_on.cpp b/sycl/test-e2e/Graph/invalid_depends_on.cpp new file mode 100644 index 0000000000000..bd0f2424c7383 --- /dev/null +++ b/sycl/test-e2e/Graph/invalid_depends_on.cpp @@ -0,0 +1,68 @@ +// REQUIRES: level_zero, gpu +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// Tests that calling handler::depends_on() for events not part of the graph +// throws. + +#include "graph_common.hpp" + +int main() { + queue Queue; + + ext::oneapi::experimental::command_graph Graph{Queue.get_context(), + Queue.get_device()}; + ext::oneapi::experimental::command_graph Graph2{Queue.get_context(), + Queue.get_device()}; + + auto NormalEvent = Queue.submit( + [&](handler &CGH) { CGH.single_task([=]() {}); }); + + Graph2.begin_recording(Queue); + + auto OtherGraphEvent = Queue.submit( + [&](handler &CGH) { CGH.single_task([=]() {}); }); + + Graph2.end_recording(Queue); + + Graph.begin_recording(Queue); + + // Test that depends_on in explicit and record and replay throws from an event + // outside any graph. + try { + auto GraphEvent = Queue.submit([&](handler &CGH) { + CGH.depends_on(NormalEvent); + CGH.single_task([=]() {}); + }); + } catch (const sycl::exception &e) { + assert(e.code() == sycl::errc::invalid); + } + try { + Graph.add([&](handler &CGH) { + CGH.depends_on(NormalEvent); + CGH.single_task([=]() {}); + }); + } catch (const sycl::exception &e) { + assert(e.code() == sycl::errc::invalid); + } + + // Test that depends_on throws from an event from another graph. + try { + auto GraphEvent = Queue.submit([&](handler &CGH) { + CGH.depends_on(OtherGraphEvent); + CGH.single_task([=]() {}); + }); + } catch (const sycl::exception &e) { + assert(e.code() == sycl::errc::invalid); + } + try { + Graph.add([&](handler &CGH) { + CGH.depends_on(OtherGraphEvent); + CGH.single_task([=]() {}); + }); + } catch (const sycl::exception &e) { + assert(e.code() == sycl::errc::invalid); + } + + return 0; +} diff --git a/sycl/test-e2e/Graph/invalid_event_wait.cpp b/sycl/test-e2e/Graph/invalid_event_wait.cpp new file mode 100644 index 0000000000000..63b6b74da7bc1 --- /dev/null +++ b/sycl/test-e2e/Graph/invalid_event_wait.cpp @@ -0,0 +1,29 @@ +// REQUIRES: level_zero, gpu +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// Tests that waiting on an event returned from a Record and Replay submission +// throws. + +#include "graph_common.hpp" + +int main() { + queue Queue; + + ext::oneapi::experimental::command_graph Graph{Queue.get_context(), + Queue.get_device()}; + Graph.begin_recording(Queue); + + auto GraphEvent = Queue.submit( + [&](handler &CGH) { CGH.single_task([=]() {}); }); + + Graph.end_recording(Queue); + + try { + GraphEvent.wait(); + } catch (const sycl::exception &e) { + assert(e.code() == sycl::errc::invalid); + } + + return 0; +} diff --git a/sycl/test-e2e/Graph/invalid_queue_wait.cpp b/sycl/test-e2e/Graph/invalid_queue_wait.cpp new file mode 100644 index 0000000000000..f69da9d96444a --- /dev/null +++ b/sycl/test-e2e/Graph/invalid_queue_wait.cpp @@ -0,0 +1,23 @@ +// REQUIRES: level_zero, gpu +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// Tests that waiting on a Queue in recording mode throws. + +#include "graph_common.hpp" + +int main() { + queue Queue; + + ext::oneapi::experimental::command_graph Graph{Queue.get_context(), + Queue.get_device()}; + Graph.begin_recording(Queue); + + try { + Queue.wait(); + } catch (const sycl::exception &e) { + assert(e.code() == sycl::errc::invalid); + } + + return 0; +}