diff --git a/sycl/source/detail/queue_impl.cpp b/sycl/source/detail/queue_impl.cpp index f1c8de2c2e25c..7f6cd6eba1563 100644 --- a/sycl/source/detail/queue_impl.cpp +++ b/sycl/source/detail/queue_impl.cpp @@ -81,6 +81,17 @@ event queue_impl::memset(const std::shared_ptr &Self, // Emit a begin/end scope for this call PrepareNotify.scopedNotify((uint16_t)xpti::trace_point_type_t::task_begin); #endif + // If we have a command graph set we need to capture the memset through normal + // queue submission rather than execute the memset directly. + if (MGraph.lock()) { + return submit( + [&](handler &CGH) { + CGH.depends_on(DepEvents); + CGH.memset(Ptr, Value, Count); + }, + Self, {}); + } + if (MHasDiscardEventsSupport) { MemoryManager::fill_usm(Ptr, Self, Count, Value, getOrWaitEvents(DepEvents, MContext), nullptr); diff --git a/sycl/test-e2e/Graph/RecordReplay/usm_memset_shortcut.cpp b/sycl/test-e2e/Graph/RecordReplay/usm_memset_shortcut.cpp new file mode 100644 index 0000000000000..d170c9607d821 --- /dev/null +++ b/sycl/test-e2e/Graph/RecordReplay/usm_memset_shortcut.cpp @@ -0,0 +1,47 @@ +// REQUIRES: cuda || level_zero, gpu +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out +// Extra run to check for leaks in Level Zero using ZE_DEBUG +// RUN: %if ext_oneapi_level_zero %{env ZE_DEBUG=4 %{run} %t.out 2>&1 | FileCheck %s %} +// +// CHECK-NOT: LEAK +// +// Tests adding a USM memset queue shortcut operation as a graph node. + +#include "../graph_common.hpp" + +int main() { + + queue Queue; + + exp_ext::command_graph Graph{Queue.get_context(), Queue.get_device()}; + + const size_t N = 10; + unsigned char *Arr = malloc_device(N, Queue); + + int Value = 77; + Graph.begin_recording(Queue); + auto Init = Queue.memset(Arr, Value, N); + Queue.submit([&](handler &CGH) { + CGH.depends_on(Init); + CGH.single_task([=]() { + for (int i = 0; i < Size; i++) + Arr[i] = 2 * Arr[i]; + }); + }); + + Graph.end_recording(); + + auto ExecGraph = Graph.finalize(); + + Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecGraph); }).wait(); + + std::vector Output(N); + Queue.memcpy(Output.data(), Arr, N).wait(); + for (int i = 0; i < N; i++) + assert(Output[i] == (Value * 2)); + + sycl::free(Arr, Queue); + + return 0; +}