diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index 3268a27fbb827..d2fd42a678409 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -1198,6 +1198,28 @@ void exec_graph_impl::update(std::shared_ptr GraphImpl) { "Cannot update using a graph with mismatched node types. Each pair " "of nodes being updated must have the same type"); } + + if (const auto &CG = MNodeStorage[i]->MCommandGroup; CG) { + sycl::detail::CGExecKernel *TargetCGExec = + static_cast(CG.get()); + const std::string &TargetKernelName = TargetCGExec->getKernelName(); + + sycl::detail::CGExecKernel *SourceCGExec = + static_cast( + GraphImpl->MNodeStorage[i]->MCommandGroup.get()); + const std::string &SourceKernelName = SourceCGExec->getKernelName(); + + if (TargetKernelName.compare(SourceKernelName) != 0) { + std::stringstream ErrorStream( + "Cannot update using a graph with mismatched kernel " + "types. Source node type "); + ErrorStream << SourceKernelName; + ErrorStream << ", target node type "; + ErrorStream << TargetKernelName; + throw sycl::exception(sycl::make_error_code(errc::invalid), + ErrorStream.str()); + } + } } } diff --git a/sycl/test-e2e/Graph/Update/whole_update_kernel_type_mismatch.cpp b/sycl/test-e2e/Graph/Update/whole_update_kernel_type_mismatch.cpp new file mode 100644 index 0000000000000..b077d8dc82af9 --- /dev/null +++ b/sycl/test-e2e/Graph/Update/whole_update_kernel_type_mismatch.cpp @@ -0,0 +1,108 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out +// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG +// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} +// Extra run to check for immediate-command-list in Level Zero +// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} + +#include "../graph_common.hpp" + +void testFunctors(queue Queue, int *Data) { + exp_ext::command_graph Graph{Queue}; + exp_ext::command_graph UpdateGraph{Queue}; + struct KernelFunctorA { + KernelFunctorA(int *Data) : Data(Data) {} + + void operator()() const { Data[0] = 5; } + + int *Data; + }; + + struct KernelFunctorB { + KernelFunctorB(int *Data) : Data(Data) {} + void operator()() const { Data[0] = 5; } + + int *Data; + }; + + Graph.add([&](handler &CGH) { CGH.single_task(KernelFunctorA{Data}); }); + + UpdateGraph.add([&](handler &CGH) { CGH.single_task(KernelFunctorB{Data}); }); + + auto GraphExec = Graph.finalize(exp_ext::property::graph::updatable{}); + + // Check it's an error if kernel types don't match + std::error_code ErrorCode = make_error_code(sycl::errc::success); + try { + GraphExec.update(UpdateGraph); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); +} + +void testUnNamedLambdas(queue Queue, int *Data) { + exp_ext::command_graph Graph{Queue}; + exp_ext::command_graph UpdateGraph{Queue}; + + Graph.add([&](handler &CGH) { CGH.single_task([=]() { Data[0] = 4; }); }); + + UpdateGraph.add( + [&](handler &CGH) { CGH.single_task([=]() { Data[0] = 5; }); }); + + auto GraphExec = Graph.finalize(exp_ext::property::graph::updatable{}); + + // Check it's an error if kernel types don't match + std::error_code ErrorCode = make_error_code(sycl::errc::success); + try { + GraphExec.update(UpdateGraph); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); +} +void testNamedLambdas(queue Queue, int *Data) { + exp_ext::command_graph Graph{Queue}; + exp_ext::command_graph UpdateGraph{Queue}; + + auto LambdaA = [=]() { + for (int i = 0; i < Size; i++) { + Data[i] = i; + } + }; + + Graph.add([&](handler &CGH) { CGH.single_task(LambdaA); }); + + auto LambdaB = [=]() { + for (int i = 0; i < Size; i++) { + Data[i] = i * 2; + } + }; + + UpdateGraph.add( + [&](handler &CGH) { CGH.single_task(LambdaB); }); + + auto GraphExec = Graph.finalize(exp_ext::property::graph::updatable{}); + + // Check it's an error if kernel types don't match + std::error_code ErrorCode = make_error_code(sycl::errc::success); + try { + GraphExec.update(UpdateGraph); + } catch (const sycl::exception &e) { + ErrorCode = e.code(); + } + assert(ErrorCode == sycl::errc::invalid); +} + +int main() { + queue Queue{}; + int *Data = malloc_device(Size, Queue); + + testNamedLambdas(Queue, Data); + testUnNamedLambdas(Queue, Data); + testFunctors(Queue, Data); + + sycl::free(Data, Queue); + + return 0; +}