diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 91f3a9aa336fd..2cff525b2dd63 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -1526,6 +1526,12 @@ class __SYCL_EXPORT handler { setType(detail::CG::CodeplayHostTask); } + /// @brief Get the command graph if any associated with this handler. It can + /// come from either the associated queue or from being set explicitly through + /// the appropriate constructor. + std::shared_ptr + getCommandGraph() const; + public: handler(const handler &) = delete; handler(handler &&) = delete; diff --git a/sycl/source/detail/event_impl.cpp b/sycl/source/detail/event_impl.cpp index 75807863111df..7ff666cdbb419 100644 --- a/sycl/source/detail/event_impl.cpp +++ b/sycl/source/detail/event_impl.cpp @@ -223,6 +223,12 @@ void event_impl::wait(std::shared_ptr Self) { throw sycl::exception(make_error_code(errc::invalid), "wait method cannot be used for a discarded event."); + if (MGraph.lock()) { + throw sycl::exception(make_error_code(errc::invalid), + "wait method cannot be used for an event associated " + "with a command graph."); + } + #ifdef XPTI_ENABLE_INSTRUMENTATION void *TelemetryEvent = nullptr; uint64_t IId; diff --git a/sycl/source/detail/event_impl.hpp b/sycl/source/detail/event_impl.hpp index 47d52d25e643e..073cb9e23d53f 100644 --- a/sycl/source/detail/event_impl.hpp +++ b/sycl/source/detail/event_impl.hpp @@ -23,6 +23,9 @@ namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext::oneapi::experimental::detail { +class graph_impl; +} class context; namespace detail { class plugin; @@ -265,6 +268,16 @@ class event_impl { // Get the sync point associated with this event. sycl::detail::pi::PiExtSyncPoint getSyncPoint() const { return MSyncPoint; } + void setCommandGraph( + std::shared_ptr Graph) { + MGraph = Graph; + } + + std::shared_ptr + getCommandGraph() const { + return MGraph.lock(); + } + protected: // When instrumentation is enabled emits trace event for event wait begin and // returns the telemetry event generated for the wait @@ -311,6 +324,10 @@ class event_impl { std::mutex MMutex; std::condition_variable cv; + /// Store the command graph associated with this event, if any. + /// This event is also be stored in the graph so a weak_ptr is used. + std::weak_ptr MGraph; + // If this event represents a submission to a // sycl::detail::pi::PiExtCommandBuffer the sync point for that submission is // stored here. diff --git a/sycl/source/detail/queue_impl.cpp b/sycl/source/detail/queue_impl.cpp index 8f9d3f8284fdd..adf4a660d2c9f 100644 --- a/sycl/source/detail/queue_impl.cpp +++ b/sycl/source/detail/queue_impl.cpp @@ -478,6 +478,12 @@ void queue_impl::wait(const detail::code_location &CodeLoc) { TelemetryEvent = instrumentationProlog(CodeLoc, Name, StreamID, IId); #endif + if (MGraph) { + throw sycl::exception(make_error_code(errc::invalid), + "wait cannot be called for a queue which is " + "recording to a command graph."); + } + std::vector> WeakEvents; std::vector SharedEvents; { diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 5dbd1f2dd466d..359456507729b 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -438,6 +438,8 @@ event handler::finalize() { // Associate an event with this new node and return the event. GraphImpl->addEventForNode(EventImpl, NodeImpl); + EventImpl->setCommandGraph(GraphImpl); + return detail::createSyclObjFromImpl(EventImpl); } @@ -877,18 +879,25 @@ void handler::depends_on(event Event) { throw sycl::exception(make_error_code(errc::invalid), "Queue operation cannot depend on discarded event."); } + if (auto Graph = getCommandGraph(); Graph) { + auto EventGraph = EventImpl->getCommandGraph(); + if (EventGraph == nullptr) { + throw sycl::exception( + make_error_code(errc::invalid), + "Graph nodes cannot depend on events from outside the graph."); + } + if (EventGraph != Graph) { + throw sycl::exception( + make_error_code(errc::invalid), + "Graph nodes cannot depend on events from another graph."); + } + } CGData.MEvents.push_back(EventImpl); } void handler::depends_on(const std::vector &Events) { for (const event &Event : Events) { - auto EventImpl = detail::getSyclObjImpl(Event); - if (EventImpl->isDiscarded()) { - throw sycl::exception( - make_error_code(errc::invalid), - "Queue operation cannot depend on discarded event."); - } - CGData.MEvents.push_back(EventImpl); + depends_on(Event); } } @@ -1063,6 +1072,7 @@ void handler::ext_oneapi_graph( } // Associate an event with the subgraph node. auto SubgraphEvent = std::make_shared(); + SubgraphEvent->setCommandGraph(ParentGraph); ParentGraph->addEventForNode(SubgraphEvent, MSubgraphNode); } else { // Set the exec graph for execution during finalize. @@ -1070,5 +1080,13 @@ void handler::ext_oneapi_graph( } } +std::shared_ptr +handler::getCommandGraph() const { + if (MGraph) { + return MGraph; + } + return MQueue->getCommandGraph(); +} + } // __SYCL_INLINE_VER_NAMESPACE(_V1) } // namespace sycl diff --git a/sycl/test-e2e/Graph/invalid_depends_on.cpp b/sycl/test-e2e/Graph/invalid_depends_on.cpp new file mode 100644 index 0000000000000..cd9ee51303c12 --- /dev/null +++ b/sycl/test-e2e/Graph/invalid_depends_on.cpp @@ -0,0 +1,79 @@ +// REQUIRES: level_zero, gpu +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// Tests that calling handler::depends_on() for events not part of the graph +// throws. + +#include "graph_common.hpp" + +int main() { + queue Queue; + + ext::oneapi::experimental::command_graph Graph{Queue.get_context(), + Queue.get_device()}; + ext::oneapi::experimental::command_graph Graph2{Queue.get_context(), + Queue.get_device()}; + + auto NormalEvent = Queue.submit( + [&](handler &CGH) { CGH.single_task([=]() {}); }); + + Graph2.begin_recording(Queue); + + auto OtherGraphEvent = Queue.submit( + [&](handler &CGH) { CGH.single_task([=]() {}); }); + + Graph2.end_recording(Queue); + + Graph.begin_recording(Queue); + + // Test that depends_on in explicit and record and replay throws from an event + // outside any graph. + + std::error_code ErrorCode = make_error_code(sycl::errc::success); + try { + auto GraphEvent = Queue.submit([&](handler &CGH) { + CGH.depends_on(NormalEvent); + CGH.single_task([=]() {}); + }); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); + + ErrorCode = make_error_code(sycl::errc::success); + try { + Graph.add([&](handler &CGH) { + CGH.depends_on(NormalEvent); + CGH.single_task([=]() {}); + }); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); + + // Test that depends_on throws from an event from another graph. + ErrorCode = make_error_code(sycl::errc::success); + try { + auto GraphEvent = Queue.submit([&](handler &CGH) { + CGH.depends_on(OtherGraphEvent); + CGH.single_task([=]() {}); + }); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); + + ErrorCode = make_error_code(sycl::errc::success); + try { + Graph.add([&](handler &CGH) { + CGH.depends_on(OtherGraphEvent); + CGH.single_task([=]() {}); + }); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); + + return 0; +} diff --git a/sycl/test-e2e/Graph/invalid_event_wait.cpp b/sycl/test-e2e/Graph/invalid_event_wait.cpp new file mode 100644 index 0000000000000..bb7f4ad6bedda --- /dev/null +++ b/sycl/test-e2e/Graph/invalid_event_wait.cpp @@ -0,0 +1,31 @@ +// REQUIRES: level_zero, gpu +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// Tests that waiting on an event returned from a Record and Replay submission +// throws. + +#include "graph_common.hpp" + +int main() { + queue Queue; + + ext::oneapi::experimental::command_graph Graph{Queue.get_context(), + Queue.get_device()}; + Graph.begin_recording(Queue); + + auto GraphEvent = Queue.submit( + [&](handler &CGH) { CGH.single_task([=]() {}); }); + + Graph.end_recording(Queue); + + std::error_code ErrorCode = make_error_code(sycl::errc::success); + try { + GraphEvent.wait(); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); + + return 0; +} diff --git a/sycl/test-e2e/Graph/invalid_queue_wait.cpp b/sycl/test-e2e/Graph/invalid_queue_wait.cpp new file mode 100644 index 0000000000000..8ba8c7d1c2125 --- /dev/null +++ b/sycl/test-e2e/Graph/invalid_queue_wait.cpp @@ -0,0 +1,26 @@ +// REQUIRES: level_zero, gpu +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// Tests that waiting on a Queue in recording mode throws. + +#include "graph_common.hpp" + +int main() { + queue Queue; + + ext::oneapi::experimental::command_graph Graph{Queue.get_context(), + Queue.get_device()}; + Graph.begin_recording(Queue); + + std::error_code ErrorCode = make_error_code(sycl::errc::success); + + try { + Queue.wait(); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); + + return 0; +} diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index 293d427094261..099ec8da91d13 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -4575,6 +4575,7 @@ _ZNK4sycl3_V17context8get_infoINS0_4info7context32atomic_memory_scope_capabiliti _ZNK4sycl3_V17context8get_infoINS0_4info7context7devicesEEENS0_6detail20is_context_info_descIT_E11return_typeEv _ZNK4sycl3_V17context8get_infoINS0_4info7context8platformEEENS0_6detail20is_context_info_descIT_E11return_typeEv _ZNK4sycl3_V17context9getNativeEv +_ZNK4sycl3_V17handler15getCommandGraphEv _ZNK4sycl3_V17handler17getContextImplPtrEv _ZNK4sycl3_V17handler27isStateExplicitKernelBundleEv _ZNK4sycl3_V17handler30getOrInsertHandlerKernelBundleEb diff --git a/sycl/test/abi/sycl_symbols_windows.dump b/sycl/test/abi/sycl_symbols_windows.dump index e19f43d12086b..054c607134f0c 100644 --- a/sycl/test/abi/sycl_symbols_windows.dump +++ b/sycl/test/abi/sycl_symbols_windows.dump @@ -1009,6 +1009,7 @@ ?getChannelType@UnsampledImageAccessorBaseHost@detail@_V1@sycl@@QEBA?AW4image_channel_type@34@XZ ?getChannelType@image_impl@detail@_V1@sycl@@QEBA?AW4image_channel_type@34@XZ ?getChannelType@image_plain@detail@_V1@sycl@@IEBA?AW4image_channel_type@34@XZ +?getCommandGraph@handler@_V1@sycl@@AEBA?AV?$shared_ptr@Vgraph_impl@detail@experimental@oneapi@ext@_V1@sycl@@@std@@XZ ?getContextImplPtr@handler@_V1@sycl@@AEBAAEBV?$shared_ptr@Vcontext_impl@detail@_V1@sycl@@@std@@XZ ?getCurrentDSODir@OSUtil@detail@_V1@sycl@@SA?AV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@XZ ?getDeviceFromHandler@detail@_V1@sycl@@YA?AVdevice@23@AEAVhandler@23@@Z