[SYCL][Graph] Add implicit queue recording mechanism (#14453)
Updates the semantics of the record & replay API to allow queues to be
implicitly set to recording mode when submitting command-groups
that contain dependencies from other recording queues.

---------

Co-authored-by: Ben Tracy <ben.tracy@codeplay.com>
fabiomestre and Bensuo authored Aug 8, 2024
1 parent dd27ef2 commit bdecdd2
Showing 7 changed files with 374 additions and 20 deletions.
39 changes: 38 additions & 1 deletion sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc
@@ -1442,7 +1442,7 @@ The alternative `queue_state::recording` state is used for graph construction.
Instead of being scheduled for execution, command-groups submitted to the queue
are recorded to a graph object as new nodes for each submission. After recording
has finished and the queue returns to the executing state, the recorded commands are
-not then executed, they are transparent to any following queue operations. The state
+not executed, they are transparent to any following queue operations. The state
of a queue can be queried with `queue::ext_oneapi_get_state()`.

.Queue State Diagram
@@ -1453,6 +1453,43 @@ graph LR
Recording -->|End Recording| Executing
....

==== Transitive Queue Recording

Submitting a command-group to a queue in the executing state can implicitly
change its state to `queue_state::recording`. This will occur when the
command-group depends on an event that has been returned by a queue in the
recording state. The change of state happens before the command-group is
submitted to the device (i.e. a new graph node will be created for that command-group).

A queue whose state has been set to `queue_state::recording` using this
mechanism will behave as if it had been passed as an argument to
`command_graph::begin_recording()`. In particular, its state will not change
again until `command_graph::end_recording()` is called.

The recording properties of the queue whose event triggered the state change
will also be inherited (i.e. any properties passed to the original call of
`command_graph::begin_recording()` will be inherited by the queue whose state
is being transitioned).

===== Example

[source,c++]
----
// q1 state is set to recording.
graph.begin_recording(q1);
// Node is added to the graph by submitting to a recording queue.
auto e1 = q1.single_task(...);
// Since there is a dependency on e1 which was created by a queue being
// recorded, q2 immediately enters record mode, and a new node is created
// with an edge between e1 and e2.
auto e2 = q2.single_task(e1, ...);
// Ends recording on q1 and q2.
graph.end_recording();
----
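
The state transitions described in this section can be illustrated with a small stand-alone model. This is only a sketch under stated assumptions: `ToyQueue`, `ToyGraph`, `ToyEvent` and `QueueState` are hypothetical names that do not exist in SYCL, the model uses only the C++ standard library, and it ignores error cases such as a queue that is already recording to a different graph.

```cpp
#include <cassert>
#include <unordered_set>

// Toy stand-ins for illustration only; these are NOT the SYCL classes.
enum class QueueState { executing, recording };

struct ToyGraph;

struct ToyEvent {
  // Graph that recorded the producing command, or nullptr if the command
  // was submitted for normal execution.
  ToyGraph *graph = nullptr;
};

struct ToyQueue {
  ToyGraph *graph = nullptr; // non-null while the queue is recording

  QueueState state() const {
    return graph ? QueueState::recording : QueueState::executing;
  }

  // Submits a command with an optional event dependency.
  ToyEvent submit(const ToyEvent *dep = nullptr);
};

struct ToyGraph {
  std::unordered_set<ToyQueue *> recording_queues;
  int node_count = 0;

  void begin_recording(ToyQueue &q) {
    q.graph = this;
    recording_queues.insert(&q);
  }

  // Ends recording on every queue attached to this graph, including
  // queues that entered the recording state implicitly.
  void end_recording() {
    for (ToyQueue *q : recording_queues)
      q->graph = nullptr;
    recording_queues.clear();
  }
};

inline ToyEvent ToyQueue::submit(const ToyEvent *dep) {
  // Transitive recording: a dependency on a recorded event switches an
  // executing queue to recording before the command is processed.
  if (graph == nullptr && dep != nullptr && dep->graph != nullptr)
    dep->graph->begin_recording(*this);

  ToyEvent result;
  if (graph != nullptr) {
    ++graph->node_count; // the command becomes a graph node
    result.graph = graph;
  }
  return result;
}
```

In this model, a `ToyQueue` that submits with a dependency on an event from a recording queue starts recording itself, keeps recording for later submissions that have no dependencies, and only returns to the executing state when `end_recording()` is called on the graph.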

==== Queue Properties

:queue-properties: https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#sec:queue-properties
30 changes: 13 additions & 17 deletions sycl/source/detail/graph_impl.cpp
@@ -621,6 +621,15 @@ std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
return Events;
}

void graph_impl::beginRecording(
std::shared_ptr<sycl::detail::queue_impl> Queue) {
graph_impl::WriteLock Lock(MMutex);
if (Queue->getCommandGraph() == nullptr) {
Queue->setCommandGraph(shared_from_this());
addQueue(Queue);
}
}
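
The new `beginRecording` calls `shared_from_this()`, which is only valid when the object is already owned by a `std::shared_ptr`; this is why the header change in this commit makes `graph_impl` derive from `std::enable_shared_from_this`. The pattern can be sketched in isolation as follows. `ToyGraphImpl` and `ToyQueueImpl` are hypothetical stand-ins, and a plain `std::mutex` is assumed here in place of the `std::shared_mutex` used by the real class.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <vector>

struct ToyGraphImpl;

// Hypothetical stand-in for queue_impl: it only remembers its graph.
struct ToyQueueImpl {
  std::shared_ptr<ToyGraphImpl> graph;
};

// Hypothetical stand-in for graph_impl.
struct ToyGraphImpl : std::enable_shared_from_this<ToyGraphImpl> {
  std::mutex mutex;
  std::vector<std::weak_ptr<ToyQueueImpl>> queues;

  // Idempotent: a queue that already has a graph is left untouched.
  void beginRecording(const std::shared_ptr<ToyQueueImpl> &queue) {
    std::lock_guard<std::mutex> lock(mutex);
    if (queue->graph == nullptr) {
      // Requires *this to be managed by a std::shared_ptr already.
      queue->graph = shared_from_this();
      queues.push_back(queue);
    }
  }
};
```

Creating the graph with `std::make_shared<ToyGraphImpl>()` satisfies the ownership requirement. Holding the queues through `std::weak_ptr` in this sketch avoids a reference cycle between the graph and its queues.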

// Check if nodes are empty and if so loop back through predecessors until we
// find the real dependency.
void exec_graph_impl::findRealDeps(
@@ -1584,27 +1593,14 @@ void modifiable_command_graph::begin_recording(
"can NOT be recorded.");
}

-if (QueueImpl->get_context() != impl->getContext()) {
-throw sycl::exception(sycl::make_error_code(errc::invalid),
-"begin_recording called for a queue whose context "
-"differs from the graph context.");
-}
-if (QueueImpl->get_device() != impl->getDevice()) {
-throw sycl::exception(sycl::make_error_code(errc::invalid),
-"begin_recording called for a queue whose device "
-"differs from the graph device.");
-}
-
-if (QueueImpl->getCommandGraph() == nullptr) {
-QueueImpl->setCommandGraph(impl);
-graph_impl::WriteLock Lock(impl->MMutex);
-impl->addQueue(QueueImpl);
-}
-if (QueueImpl->getCommandGraph() != impl) {
+auto QueueGraph = QueueImpl->getCommandGraph();
+if (QueueGraph != nullptr && QueueGraph != impl) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"begin_recording called for a queue which is already "
"recording to a different graph.");
}
+
+impl->beginRecording(QueueImpl);
}

void modifiable_command_graph::begin_recording(
7 changes: 6 additions & 1 deletion sycl/source/detail/graph_impl.hpp
@@ -850,7 +850,7 @@ class partition {
};

/// Implementation details of command_graph<modifiable>.
-class graph_impl {
+class graph_impl : public std::enable_shared_from_this<graph_impl> {
public:
using ReadLock = std::shared_lock<std::shared_mutex>;
using WriteLock = std::unique_lock<std::shared_mutex>;
@@ -1194,6 +1194,11 @@ class graph_impl {
std::vector<sycl::detail::EventImplPtr>
getExitNodesEvents(std::weak_ptr<sycl::detail::queue_impl> Queue);

/// Sets the Queue state to queue_state::recording. Adds the queue to the list
/// of recording queues associated with this graph.
/// @param[in] Queue The queue to be recorded from.
void beginRecording(std::shared_ptr<sycl::detail::queue_impl> Queue);

/// Store the last barrier node that was submitted to the queue.
/// @param[in] Queue The queue the barrier was recorded from.
/// @param[in] BarrierNodeImpl The created barrier node.
41 changes: 40 additions & 1 deletion sycl/source/handler.cpp
@@ -1544,8 +1544,47 @@ void handler::depends_on(const detail::EventImplPtr &EventImpl) {
throw sycl::exception(make_error_code(errc::invalid),
"Queue operation cannot depend on discarded event.");
}

auto EventGraph = EventImpl->getCommandGraph();
if (MQueue && EventGraph) {
auto QueueGraph = MQueue->getCommandGraph();

if (EventGraph->getContext() != MQueue->get_context()) {
throw sycl::exception(
make_error_code(errc::invalid),
"Cannot submit to a queue with a dependency from a graph that is "
"associated with a different context.");
}

if (EventGraph->getDevice() != MQueue->get_device()) {
throw sycl::exception(
make_error_code(errc::invalid),
"Cannot submit to a queue with a dependency from a graph that is "
"associated with a different device.");
}

if (MQueue->is_in_fusion_mode()) {
throw sycl::exception(
sycl::make_error_code(errc::invalid),
"Queue in fusion mode cannot have a dependency from a graph");
}

if (QueueGraph && QueueGraph != EventGraph) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"Cannot submit to a recording queue with a "
"dependency from a different graph.");
}

// If the event dependency has a graph, that means that the queue that
// created it was in recording mode. If the current queue is not recording,
// we need to set it to recording (implements the transitive queue recording
// feature).
if (!QueueGraph) {
EventGraph->beginRecording(MQueue);
}
}
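
The order of the checks in the block above can be summarised with a small decision function. This is a sketch, not the runtime code: `DepOutcome`, `ToyGraphInfo`, `ToyQueueInfo` and `classify` are hypothetical names, contexts and devices are reduced to plain integers, and the thrown exceptions are replaced by enum values so that the precedence is easy to see.

```cpp
#include <cassert>

// Hypothetical result type; the real code throws sycl::exception instead.
enum class DepOutcome {
  no_action,        // this block neither throws nor changes queue state
  start_recording,  // queue is implicitly switched to recording
  context_mismatch,
  device_mismatch,
  fusion_mode,
  different_graph
};

// Contexts and devices are modelled as integer ids for illustration.
struct ToyGraphInfo {
  int context;
  int device;
};

struct ToyQueueInfo {
  int context;
  int device;
  bool fusion_mode;
  const ToyGraphInfo *graph; // nullptr when the queue is not recording
};

// Mirrors the order of the checks for one event dependency.
inline DepOutcome classify(const ToyQueueInfo &queue,
                           const ToyGraphInfo *event_graph) {
  if (event_graph == nullptr)
    return DepOutcome::no_action; // event was not produced by a recording
  if (event_graph->context != queue.context)
    return DepOutcome::context_mismatch;
  if (event_graph->device != queue.device)
    return DepOutcome::device_mismatch;
  if (queue.fusion_mode)
    return DepOutcome::fusion_mode;
  if (queue.graph != nullptr && queue.graph != event_graph)
    return DepOutcome::different_graph;
  if (queue.graph == nullptr)
    return DepOutcome::start_recording;
  return DepOutcome::no_action; // already recording to the same graph
}
```

The sketch covers only the newly added block. The code that follows it in `depends_on` still rejects the opposite case, where a recording queue depends on an event that has no graph.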

if (auto Graph = getCommandGraph(); Graph) {
-auto EventGraph = EventImpl->getCommandGraph();
if (EventGraph == nullptr) {
throw sycl::exception(
make_error_code(errc::invalid),
149 changes: 149 additions & 0 deletions sycl/test-e2e/Graph/RecordReplay/transitive_queue.cpp
@@ -0,0 +1,149 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
// Extra run to check for immediate-command-list in Level Zero
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}

// Checks that the transitive queue recording feature is working as expected.
// i.e. submitting a command group function to a queue that has a dependency
// from a graph, should change the state of the queue to recording mode.

#include "../graph_common.hpp"

int main() {
using T = int;

device Dev;
context Ctx{Dev};
queue Q1{Ctx, Dev};
queue Q2{Ctx, Dev};
queue Q3{Ctx, Dev};

const exp_ext::queue_state Recording = exp_ext::queue_state::recording;
const exp_ext::queue_state Executing = exp_ext::queue_state::executing;

auto assertQueueState = [&](exp_ext::queue_state ExpectedQ1,
exp_ext::queue_state ExpectedQ2,
exp_ext::queue_state ExpectedQ3) {
assert(Q1.ext_oneapi_get_state() == ExpectedQ1);
assert(Q2.ext_oneapi_get_state() == ExpectedQ2);
assert(Q3.ext_oneapi_get_state() == ExpectedQ3);
};

std::vector<T> DataA(Size), DataB(Size), DataC(Size);

std::iota(DataA.begin(), DataA.end(), 1);
std::iota(DataB.begin(), DataB.end(), 10);
std::iota(DataC.begin(), DataC.end(), 1000);

std::vector<T> ReferenceA(DataA), ReferenceB(DataB), ReferenceC(DataC);

T *PtrA = malloc_device<T>(Size, Q1);
T *PtrB = malloc_device<T>(Size, Q1);
T *PtrC = malloc_device<T>(Size, Q1);

Q1.copy(DataA.data(), PtrA, Size);
Q1.copy(DataB.data(), PtrB, Size);
Q1.copy(DataC.data(), PtrC, Size);
Q1.wait_and_throw();

exp_ext::command_graph Graph{Q1.get_context(), Q1.get_device()};

Graph.begin_recording(Q1);
assertQueueState(Recording, Executing, Executing);

auto GraphEventA = Q1.submit([&](handler &CGH) {
CGH.parallel_for(range<1>(Size), [=](item<1> Id) { PtrA[Id]++; });
});
assertQueueState(Recording, Executing, Executing);

// Since there is a dependency on GraphEventA which is part of a graph,
// this will change Q2 to the recording state.
auto GraphEventB = Q2.submit([&](handler &CGH) {
CGH.depends_on(GraphEventA);
CGH.parallel_for(range<1>(Size), [=](item<1> Id) { PtrA[Id]++; });
});

// Has no dependencies but should still be recorded to the graph because
// the queue was implicitly changed to recording mode previously.
auto GraphEventC = Q2.submit([&](handler &CGH) {
CGH.parallel_for(range<1>(Size), [=](item<1> Id) { PtrB[Id]++; });
});
assertQueueState(Recording, Recording, Executing);

// Q2 is now in recording mode. Submitting a command group to Q3 with a
// dependency on an event from Q2 should change it to recording mode as well.
auto GraphEventD = Q3.submit([&](handler &CGH) {
CGH.depends_on(GraphEventB);
CGH.parallel_for(range<1>(Size), [=](item<1> Id) { PtrC[Id]++; });
});
assertQueueState(Recording, Recording, Recording);

Graph.end_recording(Q1);
assertQueueState(Executing, Recording, Recording);
Graph.end_recording(Q2);
assertQueueState(Executing, Executing, Recording);

auto GraphEventE = Q1.submit([&](handler &CGH) {
CGH.depends_on(GraphEventD);
CGH.parallel_for(range<1>(Size), [=](item<1> Id) { PtrC[Id]++; });
});
assertQueueState(Recording, Executing, Recording);

Graph.end_recording(Q1);
assertQueueState(Executing, Executing, Recording);

// Q2 is not recording anymore. So this will be submitted outside the graph.
auto OutsideEventA = Q2.submit([&](handler &CGH) {
CGH.parallel_for(range<1>(Size), [=](item<1> Id) { PtrC[Id] /= 2; });
});
assertQueueState(Executing, Executing, Recording);

try {
// Q3 should still be recording. Adding a dependency from an event outside
// the graph should fail.
auto EventF = Q3.submit([&](handler &CGH) {
CGH.depends_on(OutsideEventA);
CGH.parallel_for(range<1>(Size), [=](item<1> Id) { PtrC[Id]++; });
});
} catch (exception &E) {
assert(E.code() == sycl::errc::invalid);
assertQueueState(Executing, Executing, Recording);
}

Q2.wait_and_throw();

Q1.copy(PtrA, DataA.data(), Size);
Q1.copy(PtrB, DataB.data(), Size);
Q1.copy(PtrC, DataC.data(), Size);
Q1.wait_and_throw();

// Check that only DataC was changed before running the graph
for (size_t i = 0; i < Size; i++) {
assert(check_value(i, ReferenceA[i], DataA[i], "DataA"));
assert(check_value(i, ReferenceB[i], DataB[i], "DataB"));
assert(check_value(i, ReferenceC[i] / 2, DataC[i], "DataC"));
}

Graph.end_recording();
assertQueueState(Executing, Executing, Executing);

auto GraphExec = Graph.finalize();

Q1.submit([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
Q1.wait_and_throw();

Q1.copy(PtrA, DataA.data(), Size);
Q1.copy(PtrB, DataB.data(), Size);
Q1.copy(PtrC, DataC.data(), Size);
Q1.wait_and_throw();

for (size_t i = 0; i < Size; i++) {
assert(check_value(i, ReferenceA[i] + 2, DataA[i], "DataA"));
assert(check_value(i, ReferenceB[i] + 1, DataB[i], "DataB"));
assert(check_value(i, ReferenceC[i] / 2 + 2, DataC[i], "DataC"));
}

return 0;
}
47 changes: 47 additions & 0 deletions sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp
@@ -586,3 +586,50 @@ TEST_F(CommandGraphTest, AccessorModeEdges) {
Queue);
testAccessorModeCombo<access_mode::atomic, access_mode::atomic, true>(Queue);
}

// Tests the transitive queue recording behaviour with queue shortcuts.
TEST_F(CommandGraphTest, TransitiveRecordingShortcuts) {
device Dev;
context Ctx{{Dev}};
queue Q1{Ctx, Dev};
queue Q2{Ctx, Dev};
queue Q3{Ctx, Dev};

ext::oneapi::experimental::command_graph Graph1{Q1.get_context(),
Q1.get_device()};

Graph1.begin_recording(Q1);

auto GraphEvent1 = Q1.single_task<class Kernel1>([=] {});
ASSERT_EQ(Q1.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::recording);
ASSERT_EQ(Q2.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::executing);
ASSERT_EQ(Q3.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::executing);

auto GraphEvent2 = Q2.single_task<class Kernel2>(GraphEvent1, [=] {});
ASSERT_EQ(Q1.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::recording);
ASSERT_EQ(Q2.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::recording);
ASSERT_EQ(Q3.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::executing);

auto GraphEvent3 = Q3.parallel_for<class Kernel3>(range<1>{1024}, GraphEvent1,
[=](item<1> Id) {});
ASSERT_EQ(Q1.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::recording);
ASSERT_EQ(Q2.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::recording);
ASSERT_EQ(Q3.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::recording);

Graph1.end_recording();
ASSERT_EQ(Q1.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::executing);
ASSERT_EQ(Q2.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::executing);
ASSERT_EQ(Q3.ext_oneapi_get_state(),
ext::oneapi::experimental::queue_state::executing);
}