Skip to content

Commit

Permalink
[SYCL][Graph] Implementation of whole graph update
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiomestre committed Mar 31, 2024
1 parent cefbadd commit 6e98293
Show file tree
Hide file tree
Showing 15 changed files with 513 additions and 117 deletions.
48 changes: 43 additions & 5 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,9 @@ void exec_graph_impl::createCommandBuffers(
exec_graph_impl::exec_graph_impl(sycl::context Context,
const std::shared_ptr<graph_impl> &GraphImpl,
const property_list &PropList)
: MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(), MContext(Context),
MRequirements(), MExecutionEvents(),
: MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(),
MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(),
MExecutionEvents(),
MIsUpdatable(PropList.has_property<property::graph::updatable>()) {

// If the graph has been marked as updatable then check if the backend
Expand Down Expand Up @@ -1155,9 +1156,48 @@ void exec_graph_impl::duplicateNodes() {
MNodeStorage.insert(MNodeStorage.begin(), NewNodes.begin(), NewNodes.end());
}

/// Whole-graph update: re-targets this executable graph at the nodes of
/// \p GraphImpl.
///
/// The supplied graph must have been created with the same device and
/// context as this executable graph and must contain the same number of
/// nodes, with each node (compared position-by-position in node storage)
/// having the same number of successors and predecessors. Only the edge
/// *counts* are compared here, not the identity of the connected nodes.
///
/// On success every node ID of \p GraphImpl is registered in MIDCache against
/// the executable node at the same storage index, and the update is then
/// delegated to the node-list overload of update().
///
/// @param GraphImpl Modifiable graph implementation to take the update from.
/// @throws sycl::exception with errc::invalid if the device, context, node
/// count or per-node edge counts do not match.
void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
  if (MDevice != GraphImpl->getDevice()) {
    throw sycl::exception(
        sycl::make_error_code(errc::invalid),
        "The graphs must have been created with matching devices.");
  }
  if (MContext != GraphImpl->getContext()) {
    throw sycl::exception(
        sycl::make_error_code(errc::invalid),
        "The graphs must have been created with matching contexts.");
  }

  const auto &NewNodes = GraphImpl->MNodeStorage;
  const size_t NumNodes = MNodeStorage.size();

  if (NumNodes != NewNodes.size()) {
    throw sycl::exception(sycl::make_error_code(errc::invalid),
                          "Mismatch found in the number of nodes. The graphs "
                          "must have a matching topology.");
  }

  // Validate every node before touching MIDCache so that a failed update
  // leaves the cache unmodified.
  for (size_t i = 0; i < NumNodes; ++i) {
    const auto &ExecNode = MNodeStorage[i];
    const auto &NewNode = NewNodes[i];
    if (ExecNode->MSuccessors.size() != NewNode->MSuccessors.size() ||
        ExecNode->MPredecessors.size() != NewNode->MPredecessors.size()) {
      throw sycl::exception(sycl::make_error_code(errc::invalid),
                            "Mismatch found in the number of edges. The "
                            "graphs must have a matching topology.");
    }
  }

  // Associate the IDs of the incoming nodes with the matching executable
  // nodes (presumably looked up by the node-list overload - confirm there).
  for (size_t i = 0; i < NumNodes; ++i) {
    MIDCache.insert(std::make_pair(NewNodes[i]->MID, MNodeStorage[i]));
  }

  update(NewNodes);
}

/// Single-node update: wraps \p Node in a one-element list and forwards it to
/// the node-list overload of update().
void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
  const std::vector<std::shared_ptr<node_impl>> SingleNode{Node};
  update(SingleNode);
}

void exec_graph_impl::update(
const std::vector<std::shared_ptr<node_impl>> Nodes) {

Expand Down Expand Up @@ -1598,9 +1638,7 @@ void executable_command_graph::finalizeImpl() {

/// Updates this executable graph from the modifiable graph \p Graph by
/// forwarding the graph's implementation object to the implementation's
/// whole-graph update().
///
/// @param Graph Modifiable graph to take the update from.
void executable_command_graph::update(
    const command_graph<graph_state::modifiable> &Graph) {
  impl->update(sycl::detail::getSyclObjImpl(Graph));
}

void executable_command_graph::update(const node &Node) {
Expand Down
7 changes: 7 additions & 0 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,10 @@ class exec_graph_impl {
void createCommandBuffers(sycl::device Device,
std::shared_ptr<partition> &Partition);

/// Query for the device tied to this executable graph.
/// @return Copy of the device (MDevice) associated with this graph.
sycl::device getDevice() const { return MDevice; }

/// Query for the context tied to this executable graph.
/// @return Copy of the context (MContext) associated with this graph.
sycl::context getContext() const { return MContext; }
Expand Down Expand Up @@ -1320,6 +1324,7 @@ class exec_graph_impl {
return MRequirements;
}

/// Whole-graph update from a modifiable graph implementation.
/// @param GraphImpl Graph whose nodes are used to update this executable
/// graph.
void update(std::shared_ptr<graph_impl> GraphImpl);
/// Update using a single node.
/// @param Node Node to update this executable graph with.
void update(std::shared_ptr<node_impl> Node);
/// Update using a list of nodes.
/// @param Nodes Nodes to update this executable graph with.
void update(const std::vector<std::shared_ptr<node_impl>> Nodes);

Expand Down Expand Up @@ -1408,6 +1413,8 @@ class exec_graph_impl {
/// Map of nodes in the exec graph to the partition number to which they
/// belong.
std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
/// Device associated with this executable graph.
sycl::device MDevice;
/// Context associated with this executable graph.
sycl::context MContext;
/// List of requirements for enqueueing this command graph, accumulated from
Expand Down
2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/Explicit/double_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

#define GRAPH_E2E_EXPLICIT

#include "../Inputs/double_buffer.cpp"
#include "../Update/whole_update_double_buffer.cpp"
2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/Explicit/executable_graph_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

#define GRAPH_E2E_EXPLICIT

#include "../Inputs/executable_graph_update.cpp"
#include "../Update/whole_update_usm.cpp"
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

#define GRAPH_E2E_EXPLICIT

#include "../Inputs/executable_graph_update_ordering.cpp"
#include "../Update/whole_update_delay.cpp"
104 changes: 0 additions & 104 deletions sycl/test-e2e/Graph/Inputs/double_buffer.cpp

This file was deleted.

2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/RecordReplay/double_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

#define GRAPH_E2E_RECORD_REPLAY

#include "../Inputs/double_buffer.cpp"
#include "../Update/whole_update_double_buffer.cpp"
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

#define GRAPH_E2E_RECORD_REPLAY

#include "../Inputs/executable_graph_update.cpp"
#include "../Update/whole_update_usm.cpp"
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

#define GRAPH_E2E_RECORD_REPLAY

#include "../Inputs/executable_graph_update_ordering.cpp"
#include "../Update/whole_update_delay.cpp"
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Tests executable graph update by introducing a delay into the update
// transaction's dependencies to check correctness of behaviour.

// TODO This test is disabled because host-tasks are not supported for graph
// updates yet.
#include "../graph_common.hpp"
#include <thread>

Expand Down
107 changes: 107 additions & 0 deletions sycl/test-e2e/Graph/Update/whole_update_double_buffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
// Extra run to check for immediate-command-list in Level Zero
// RUN: %if level_zero && linux %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
//
// UNSUPPORTED: opencl, level_zero

// Tests executable graph update by creating a double buffering scenario, where
// a single graph is repeatedly executed then updated to swap between two sets
// of buffers.
#define GRAPH_E2E_EXPLICIT

#include "../graph_common.hpp"

int main() {
queue Queue{};

using T = int;

std::vector<T> DataA(Size), DataB(Size), DataC(Size);
std::vector<T> DataA2(Size), DataB2(Size), DataC2(Size);

std::iota(DataA.begin(), DataA.end(), 1);
std::iota(DataB.begin(), DataB.end(), 10);
std::iota(DataC.begin(), DataC.end(), 1000);

std::iota(DataA2.begin(), DataA2.end(), 3);
std::iota(DataB2.begin(), DataB2.end(), 13);
std::iota(DataC2.begin(), DataC2.end(), 1333);

std::vector<T> ReferenceA(DataA), ReferenceB(DataB), ReferenceC(DataC);
std::vector<T> ReferenceA2(DataA2), ReferenceB2(DataB2), ReferenceC2(DataC2);

calculate_reference_data(Iterations, Size, ReferenceA, ReferenceB,
ReferenceC);
calculate_reference_data(Iterations, Size, ReferenceA2, ReferenceB2,
ReferenceC2);

buffer<T> BufferA{DataA};
buffer<T> BufferB{DataB};
buffer<T> BufferC{DataC};

buffer<T> BufferA2{DataA2};
buffer<T> BufferB2{DataB2};
buffer<T> BufferC2{DataC2};

BufferA.set_write_back(false);
BufferB.set_write_back(false);
BufferC.set_write_back(false);
BufferA2.set_write_back(false);
BufferB2.set_write_back(false);
BufferC2.set_write_back(false);

Queue.wait_and_throw();
{
exp_ext::command_graph Graph{
Queue.get_context(), Queue.get_device(),
exp_ext::property::graph::assume_buffer_outlives_graph{}};
add_nodes(Graph, Queue, Size, BufferA, BufferB, BufferC);

auto ExecGraph = Graph.finalize(exp_ext::property::graph::updatable{});

// Create second graph using other buffer set
exp_ext::command_graph GraphUpdate{
Queue.get_context(), Queue.get_device(),
exp_ext::property::graph::assume_buffer_outlives_graph{}};
add_nodes(GraphUpdate, Queue, Size, BufferA2, BufferB2, BufferC2);

event Event;
for (size_t i = 0; i < Iterations; i++) {
Event = Queue.submit([&](handler &CGH) {
CGH.depends_on(Event);
CGH.ext_oneapi_graph(ExecGraph);
});
// Update to second set of buffers
ExecGraph.update(GraphUpdate);
Event = Queue.submit([&](handler &CGH) {
CGH.depends_on(Event);
CGH.ext_oneapi_graph(ExecGraph);
});
// Reset back to original buffers
ExecGraph.update(Graph);
}

Queue.wait_and_throw();
}
host_accessor HostDataA(BufferA);
host_accessor HostDataB(BufferB);
host_accessor HostDataC(BufferC);
host_accessor HostDataA2(BufferA2);
host_accessor HostDataB2(BufferB2);
host_accessor HostDataC2(BufferC2);

for (size_t i = 0; i < Size; i++) {
assert(check_value(i, ReferenceA[i], HostDataA[i], "DataA"));
assert(check_value(i, ReferenceB[i], HostDataB[i], "DataB"));
assert(check_value(i, ReferenceC[i], HostDataC[i], "DataC"));

assert(check_value(i, ReferenceA2[i], HostDataA2[i], "DataA2"));
assert(check_value(i, ReferenceB2[i], HostDataB2[i], "DataB2"));
assert(check_value(i, ReferenceC2[i], HostDataC2[i], "DataC2"));
}

return 0;
}
Loading

0 comments on commit 6e98293

Please sign in to comment.