Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Graph] Implementation of whole graph update #366

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,9 @@ void exec_graph_impl::createCommandBuffers(
exec_graph_impl::exec_graph_impl(sycl::context Context,
const std::shared_ptr<graph_impl> &GraphImpl,
const property_list &PropList)
: MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(), MContext(Context),
MRequirements(), MExecutionEvents(),
: MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(),
MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(),
MExecutionEvents(),
MIsUpdatable(PropList.has_property<property::graph::updatable>()) {

// If the graph has been marked as updatable then check if the backend
Expand Down Expand Up @@ -1155,9 +1156,48 @@ void exec_graph_impl::duplicateNodes() {
MNodeStorage.insert(MNodeStorage.begin(), NewNodes.begin(), NewNodes.end());
}

/// Whole graph update: updates this executable graph using the nodes of
/// \p GraphImpl.
///
/// All compatibility checks below run before MIDCache is touched, so a
/// rejected update does not leave partially inserted cache entries behind.
/// Topology is compared only by node count and by per-node successor and
/// predecessor counts, pairing nodes by their position in MNodeStorage.
///
/// @param GraphImpl Graph whose nodes are used to update this executable
/// graph.
/// @throws sycl::exception with errc::invalid if the device, context, number
/// of nodes, or number of edges of \p GraphImpl differs from this graph.
void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
  if (MDevice != GraphImpl->getDevice()) {
    throw sycl::exception(
        sycl::make_error_code(errc::invalid),
        "Cannot update using a graph created with a different device.");
  }
  if (MContext != GraphImpl->getContext()) {
    throw sycl::exception(
        sycl::make_error_code(errc::invalid),
        "Cannot update using a graph created with a different context.");
  }

  const auto &NewNodes = GraphImpl->MNodeStorage;
  if (MNodeStorage.size() != NewNodes.size()) {
    throw sycl::exception(sycl::make_error_code(errc::invalid),
                          "Cannot update using a graph with a different "
                          "topology. Mismatch found in the number of nodes.");
  }

  const size_t NumNodes = MNodeStorage.size();

  // Validate every node pair first; nothing is modified until all pass.
  for (size_t i = 0; i < NumNodes; ++i) {
    const auto &ExecNode = MNodeStorage[i];
    const auto &NewNode = NewNodes[i];
    if (ExecNode->MSuccessors.size() != NewNode->MSuccessors.size() ||
        ExecNode->MPredecessors.size() != NewNode->MPredecessors.size()) {
      throw sycl::exception(
          sycl::make_error_code(errc::invalid),
          "Cannot update using a graph with a different topology. Mismatch "
          "found in the number of edges.");
    }
  }

  // Associate each incoming node's ID with the executable node stored at the
  // same position.
  for (size_t i = 0; i < NumNodes; ++i) {
    MIDCache.insert(std::make_pair(NewNodes[i]->MID, MNodeStorage[i]));
  }

  update(GraphImpl->MNodeStorage);
}

void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
this->update(std::vector<std::shared_ptr<node_impl>>{Node});
}

void exec_graph_impl::update(
const std::vector<std::shared_ptr<node_impl>> Nodes) {

Expand Down Expand Up @@ -1598,9 +1638,7 @@ void executable_command_graph::finalizeImpl() {

/// Updates this executable graph using a modifiable graph.
///
/// Forwards to the implementation object's update() with the implementation
/// object obtained from \p Graph.
/// @param Graph Modifiable graph used as the source of the update.
void executable_command_graph::update(
    const command_graph<graph_state::modifiable> &Graph) {
  impl->update(sycl::detail::getSyclObjImpl(Graph));
}

void executable_command_graph::update(const node &Node) {
Expand Down
7 changes: 7 additions & 0 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,10 @@ class exec_graph_impl {
void createCommandBuffers(sycl::device Device,
std::shared_ptr<partition> &Partition);

/// Accessor for the device associated with this graph.
/// @return The stored device (MDevice), returned by value.
sycl::device getDevice() const {
  return MDevice;
}

/// Accessor for the context associated with this graph.
/// @return The stored context (MContext), returned by value.
sycl::context getContext() const {
  return MContext;
}
Expand Down Expand Up @@ -1320,6 +1324,7 @@ class exec_graph_impl {
return MRequirements;
}

void update(std::shared_ptr<graph_impl> GraphImpl);
void update(std::shared_ptr<node_impl> Node);
void update(const std::vector<std::shared_ptr<node_impl>> Nodes);

Expand Down Expand Up @@ -1408,6 +1413,8 @@ class exec_graph_impl {
/// Map of nodes in the exec graph to the partition number to which they
/// belong.
std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
/// Device associated with this executable graph.
sycl::device MDevice;
/// Context associated with this executable graph.
sycl::context MContext;
/// List of requirements for enqueueing this command graph, accumulated from
Expand Down
2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/Explicit/double_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

#define GRAPH_E2E_EXPLICIT

#include "../Inputs/double_buffer.cpp"
#include "../Update/whole_update_double_buffer.cpp"
2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/Explicit/executable_graph_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

#define GRAPH_E2E_EXPLICIT

#include "../Inputs/executable_graph_update.cpp"
#include "../Update/whole_update_usm.cpp"
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

#define GRAPH_E2E_EXPLICIT

#include "../Inputs/executable_graph_update_ordering.cpp"
#include "../Update/whole_update_delay.cpp"
104 changes: 0 additions & 104 deletions sycl/test-e2e/Graph/Inputs/double_buffer.cpp

This file was deleted.

2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/RecordReplay/double_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

#define GRAPH_E2E_RECORD_REPLAY

#include "../Inputs/double_buffer.cpp"
#include "../Update/whole_update_double_buffer.cpp"
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

#define GRAPH_E2E_RECORD_REPLAY

#include "../Inputs/executable_graph_update.cpp"
#include "../Update/whole_update_usm.cpp"
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

#define GRAPH_E2E_RECORD_REPLAY

#include "../Inputs/executable_graph_update_ordering.cpp"
#include "../Update/whole_update_delay.cpp"
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Tests executable graph update by introducing a delay into the update
// transaction's dependencies to check correctness of behaviour.

// TODO This test is disabled because host-tasks are not supported for graph
// updates yet.
Bensuo marked this conversation as resolved.
Show resolved Hide resolved
#include "../graph_common.hpp"
#include <thread>

Expand Down
107 changes: 107 additions & 0 deletions sycl/test-e2e/Graph/Update/whole_update_double_buffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
// Extra run to check for immediate-command-list in Level Zero
// RUN: %if level_zero && linux %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
//
// UNSUPPORTED: opencl, level_zero

// Tests executable graph update by creating a double buffering scenario, where
// a single graph is repeatedly executed then updated to swap between two sets
// of buffers.
#define GRAPH_E2E_EXPLICIT
EwanC marked this conversation as resolved.
Show resolved Hide resolved

#include "../graph_common.hpp"

int main() {
queue Queue{};

using T = int;

std::vector<T> DataA(Size), DataB(Size), DataC(Size);
std::vector<T> DataA2(Size), DataB2(Size), DataC2(Size);

std::iota(DataA.begin(), DataA.end(), 1);
std::iota(DataB.begin(), DataB.end(), 10);
std::iota(DataC.begin(), DataC.end(), 1000);

std::iota(DataA2.begin(), DataA2.end(), 3);
std::iota(DataB2.begin(), DataB2.end(), 13);
std::iota(DataC2.begin(), DataC2.end(), 1333);

std::vector<T> ReferenceA(DataA), ReferenceB(DataB), ReferenceC(DataC);
std::vector<T> ReferenceA2(DataA2), ReferenceB2(DataB2), ReferenceC2(DataC2);

calculate_reference_data(Iterations, Size, ReferenceA, ReferenceB,
ReferenceC);
calculate_reference_data(Iterations, Size, ReferenceA2, ReferenceB2,
ReferenceC2);

buffer<T> BufferA{DataA};
buffer<T> BufferB{DataB};
buffer<T> BufferC{DataC};

buffer<T> BufferA2{DataA2};
buffer<T> BufferB2{DataB2};
buffer<T> BufferC2{DataC2};

BufferA.set_write_back(false);
BufferB.set_write_back(false);
BufferC.set_write_back(false);
BufferA2.set_write_back(false);
BufferB2.set_write_back(false);
BufferC2.set_write_back(false);

Queue.wait_and_throw();
{
exp_ext::command_graph Graph{
Queue.get_context(), Queue.get_device(),
exp_ext::property::graph::assume_buffer_outlives_graph{}};
add_nodes(Graph, Queue, Size, BufferA, BufferB, BufferC);

auto ExecGraph = Graph.finalize(exp_ext::property::graph::updatable{});

// Create second graph using other buffer set
exp_ext::command_graph GraphUpdate{
Queue.get_context(), Queue.get_device(),
exp_ext::property::graph::assume_buffer_outlives_graph{}};
add_nodes(GraphUpdate, Queue, Size, BufferA2, BufferB2, BufferC2);

event Event;
for (size_t i = 0; i < Iterations; i++) {
Event = Queue.submit([&](handler &CGH) {
CGH.depends_on(Event);
CGH.ext_oneapi_graph(ExecGraph);
});
// Update to second set of buffers
ExecGraph.update(GraphUpdate);
Event = Queue.submit([&](handler &CGH) {
CGH.depends_on(Event);
CGH.ext_oneapi_graph(ExecGraph);
});
// Reset back to original buffers
ExecGraph.update(Graph);
}

Queue.wait_and_throw();
}
host_accessor HostDataA(BufferA);
host_accessor HostDataB(BufferB);
host_accessor HostDataC(BufferC);
host_accessor HostDataA2(BufferA2);
host_accessor HostDataB2(BufferB2);
host_accessor HostDataC2(BufferC2);

for (size_t i = 0; i < Size; i++) {
assert(check_value(i, ReferenceA[i], HostDataA[i], "DataA"));
assert(check_value(i, ReferenceB[i], HostDataB[i], "DataB"));
assert(check_value(i, ReferenceC[i], HostDataC[i], "DataC"));

assert(check_value(i, ReferenceA2[i], HostDataA2[i], "DataA2"));
assert(check_value(i, ReferenceB2[i], HostDataB2[i], "DataB2"));
assert(check_value(i, ReferenceC2[i], HostDataC2[i], "DataC2"));
}

return 0;
}
Loading
Loading