diff --git a/sycl/source/detail/jit_compiler.cpp b/sycl/source/detail/jit_compiler.cpp
index 7a452115f5a94..2c177a8faf76d 100644
--- a/sycl/source/detail/jit_compiler.cpp
+++ b/sycl/source/detail/jit_compiler.cpp
@@ -672,7 +672,7 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
   for (auto &RawCmd : InputKernels) {
     auto *KernelCmd = static_cast(RawCmd);
     auto &CG = KernelCmd->getCG();
-    assert(CG.getType() == CG::Kernel);
+    assert(KernelCmd->isFusable());
     auto *KernelCG = static_cast(&CG);
 
     auto KernelName = KernelCG->MKernelName;
diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp
index 455a8353ce3f0..23b6eec33886e 100644
--- a/sycl/source/detail/scheduler/commands.cpp
+++ b/sycl/source/detail/scheduler/commands.cpp
@@ -299,6 +299,12 @@ bool Command::isHostTask() const {
           CG::CGTYPE::CodeplayHostTask);
 }
 
+bool Command::isFusable() const {
+  return (MType == CommandType::RUN_CG) &&
+         ((static_cast(this))->getCG().getType() ==
+          CG::CGTYPE::Kernel);
+}
+
 static void flushCrossQueueDeps(const std::vector &EventImpls,
                                 const QueueImplPtr &Queue) {
   for (auto &EventImpl : EventImpls) {
@@ -1825,7 +1831,7 @@ void UpdateHostRequirementCommand::emitInstrumentationData() {
 #endif
 }
 
-static std::string cgTypeToString(detail::CG::CGTYPE Type) {
+static std::string_view cgTypeToString(detail::CG::CGTYPE Type) {
   switch (Type) {
   case detail::CG::Kernel:
     return "Kernel";
@@ -1845,6 +1851,10 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
   case detail::CG::CopyPtrToAcc:
     return "copy ptr to acc";
     break;
+  case detail::CG::Barrier:
+    return "barrier";
+  case detail::CG::BarrierWaitlist:
+    return "barrier waitlist";
   case detail::CG::CopyUSM:
     return "copy usm";
     break;
@@ -1863,6 +1873,8 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
   case detail::CG::Fill2DUSM:
     return "fill 2d usm";
     break;
+  case detail::CG::AdviseUSM:
+    return "advise usm";
   case detail::CG::Memset2DUSM:
     return "memset 2d usm";
     break;
@@ -1872,6 +1884,16 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
   case detail::CG::CopyFromDeviceGlobal:
     return "copy from device_global";
     break;
+  case detail::CG::ReadWriteHostPipe:
+    return "read_write host pipe";
+  case detail::CG::ExecCommandBuffer:
+    return "exec command buffer";
+  case detail::CG::CopyImage:
+    return "copy image";
+  case detail::CG::SemaphoreWait:
+    return "semaphore wait";
+  case detail::CG::SemaphoreSignal:
+    return "semaphore signal";
   default:
     return "unknown";
     break;
@@ -2102,7 +2124,7 @@ void ExecCGCommand::emitInstrumentationData() {
         KernelCG->getKernelName(), MAddress, FromSource);
   } break;
   default:
-    KernelName = cgTypeToString(MCommandGroup->getType());
+    KernelName = getTypeString();
     break;
   }
 
@@ -2150,7 +2172,7 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
     break;
   }
   default:
-    Stream << "CG type: " << cgTypeToString(MCommandGroup->getType()) << "\\n";
+    Stream << "CG type: " << getTypeString() << "\\n";
     break;
   }
 
@@ -2165,6 +2187,10 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
   }
 }
 
+std::string_view ExecCGCommand::getTypeString() const {
+  return cgTypeToString(MCommandGroup->getType());
+}
+
 // SYCL has a parallel_for_work_group variant where the only NDRange
 // characteristics set by a user is the number of work groups. This does not
 // map to the OpenCL clEnqueueNDRangeAPI, which requires global work size to
diff --git a/sycl/source/detail/scheduler/commands.hpp b/sycl/source/detail/scheduler/commands.hpp
index b8983be03d808..7898e3f65b812 100644
--- a/sycl/source/detail/scheduler/commands.hpp
+++ b/sycl/source/detail/scheduler/commands.hpp
@@ -244,6 +244,8 @@ class Command {
 
   bool isHostTask() const;
 
+  bool isFusable() const;
+
 protected:
   QueueImplPtr MQueue;
   EventImplPtr MEvent;
@@ -648,6 +650,7 @@ class ExecCGCommand : public Command {
 
   void printDot(std::ostream &Stream) const final;
   void emitInstrumentationData() final;
+  std::string_view getTypeString() const;
 
   detail::CG &getCG() const { return *MCommandGroup; }
 
diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp
index a1d2ac5a5045b..66823c088463a 100644
--- a/sycl/source/detail/scheduler/graph_builder.cpp
+++ b/sycl/source/detail/scheduler/graph_builder.cpp
@@ -9,6 +9,7 @@
 #include "detail/config.hpp"
 #include
 #include
+#include
 #include
 #if SYCL_EXT_CODEPLAY_KERNEL_FUSION
 #include
@@ -954,61 +955,71 @@ Scheduler::GraphBuildResult Scheduler::GraphBuilder::addCG(
   // the fusion list, this will lead to cancellation of the fusion in the
   // GraphProcessor.
   auto QUniqueID = std::hash()(Queue.get());
-  if (isInFusionMode(QUniqueID) && !NewCmd->isHostTask()) {
-    auto *FusionCmd = findFusionList(QUniqueID)->second.get();
-
-    bool dependsOnFusion = false;
-    for (auto Ev = Events.begin(); Ev != Events.end();) {
-      auto *EvDepCmd = static_cast((*Ev)->getCommand());
-      if (!EvDepCmd) {
-        continue;
-      }
-      // Handle event dependencies on any commands part of another active
-      // fusion.
-      if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) {
-        printFusionWarning("Aborting fusion because of event dependency from a "
-                           "different fusion");
-        cancelFusion(EvDepCmd->getQueue(), ToEnqueue);
-      }
-      // Check if this command depends on the placeholder command for the fusion
-      // itself participates in.
-      if (EvDepCmd == FusionCmd) {
-        Ev = Events.erase(Ev);
-        dependsOnFusion = true;
-      } else {
-        ++Ev;
+  if (isInFusionMode(QUniqueID)) {
+    if (NewCmd->isFusable()) {
+      auto *FusionCmd = findFusionList(QUniqueID)->second.get();
+
+      bool dependsOnFusion = false;
+      for (auto Ev = Events.begin(); Ev != Events.end();) {
+        auto *EvDepCmd = static_cast((*Ev)->getCommand());
+        if (!EvDepCmd) {
+          // No loop increment expression: advance before skipping.
+          ++Ev;
+          continue;
+        }
+        // Handle event dependencies on any commands part of another active
+        // fusion.
+        if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) {
+          printFusionWarning(
+              "Aborting fusion because of event dependency from a "
+              "different fusion");
+          cancelFusion(EvDepCmd->getQueue(), ToEnqueue);
+        }
+        // Check if this command depends on the placeholder command for the
+        // fusion itself participates in.
+        if (EvDepCmd == FusionCmd) {
+          Ev = Events.erase(Ev);
+          dependsOnFusion = true;
+        } else {
+          ++Ev;
+        }
       }
-    }
 
-    // If this command has an explicit event dependency on the placeholder
-    // command for this fusion (because it used depends_on on the event returned
-    // by submitting another kernel to this fusion earlier), add a dependency on
-    // all the commands in the fusion list so far.
-    if (dependsOnFusion) {
-      for (auto *Cmd : FusionCmd->getFusionList()) {
-        Events.push_back(Cmd->getEvent());
+      // If this command has an explicit event dependency on the placeholder
+      // command for this fusion (because it used depends_on on the event
+      // returned by submitting another kernel to this fusion earlier), add a
+      // dependency on all the commands in the fusion list so far.
+      if (dependsOnFusion) {
+        for (auto *Cmd : FusionCmd->getFusionList()) {
+          Events.push_back(Cmd->getEvent());
+        }
       }
-    }
 
-    // Add the kernel to the graph, but delay the enqueue of any auxiliary
-    // commands (e.g., allocations) resulting from that process by adding them
-    // to the list of auxiliary commands of the fusion command.
-    createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
-                          isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
-                          FusionCmd->auxiliaryCommands());
-
-    // Set the fusion command, so we recognize when another command depends on a
-    // kernel in the fusion list.
-    FusionCmd->addToFusionList(NewCmd.get());
-    NewCmd->MFusionCmd = FusionCmd;
-    std::vector ToCleanUp;
-    // Add an event dependency from the fusion placeholder command to the new
-    // kernel.
-    auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp);
-    if (ConnectionCmd) {
-      FusionCmd->auxiliaryCommands().push_back(ConnectionCmd);
+      // Add the kernel to the graph, but delay the enqueue of any auxiliary
+      // commands (e.g., allocations) resulting from that process by adding them
+      // to the list of auxiliary commands of the fusion command.
+      createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
+                            isInteropHostTask(NewCmd.get()), Reqs, Events,
+                            Queue, FusionCmd->auxiliaryCommands());
+
+      // Set the fusion command, so we recognize when another command depends on
+      // a kernel in the fusion list.
+      FusionCmd->addToFusionList(NewCmd.get());
+      NewCmd->MFusionCmd = FusionCmd;
+      std::vector ToCleanUp;
+      // Add an event dependency from the fusion placeholder command to the new
+      // kernel.
+      auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp);
+      if (ConnectionCmd) {
+        FusionCmd->auxiliaryCommands().push_back(ConnectionCmd);
+      }
+      return {NewCmd.release(), FusionCmd->getEvent(), false};
+    } else {
+      std::ostringstream ss;
+      ss << "Not fusing '" << NewCmd->getTypeString()
+         << "' command group. Can only fuse device kernel command groups";
+      printFusionWarning(ss.str());
     }
-    return {NewCmd.release(), FusionCmd->getEvent(), false};
   }
   createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
                         isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
diff --git a/sycl/test-e2e/KernelFusion/non-kernel-cg.cpp b/sycl/test-e2e/KernelFusion/non-kernel-cg.cpp
new file mode 100644
index 0000000000000..aa796dcde291b
--- /dev/null
+++ b/sycl/test-e2e/KernelFusion/non-kernel-cg.cpp
@@ -0,0 +1,115 @@
+// RUN: %{build} -fsycl-embed-ir -o %t.out
+// RUN: env SYCL_RT_WARNING_LEVEL=2 %{run} %t.out 2>&1 | FileCheck %s
+
+// Test fusion of non-kernel device command groups is aborted.
+
+#include "sycl/detail/pi.h"
+#include
+
+using namespace sycl;
+
+int main() {
+  constexpr size_t dataSize = 512;
+  constexpr float Pattern{10};
+
+  queue q{ext::codeplay::experimental::property::queue::enable_fusion{}};
+  ext::codeplay::experimental::fusion_wrapper fw(q);
+
+  constexpr size_t count = 64;
+  auto *dst = malloc_device(count, q);
+  auto *src = malloc_device(count, q);
+
+  {
+    // CHECK: Not fusing 'copy acc to ptr' command group. Can only fuse device kernel command groups
+    buffer src(dataSize);
+    std::shared_ptr dst(new float[dataSize]);
+    fw.start_fusion();
+    q.submit([&](handler &cgh) {
+      accessor acc(src, cgh, read_only);
+      cgh.copy(acc, dst);
+    });
+    fw.complete_fusion();
+  }
+
+  {
+    // CHECK: Not fusing 'copy ptr to acc' command group. Can only fuse device kernel command groups
+    buffer dst(dataSize);
+    std::shared_ptr src(new float[dataSize]);
+    fw.start_fusion();
+    q.submit([&](handler &cgh) {
+      accessor acc(dst, cgh, write_only);
+      cgh.copy(src, acc);
+    });
+    fw.complete_fusion();
+  }
+
+  {
+    // CHECK: Not fusing 'copy acc to acc' command group. Can only fuse device kernel command groups
+    buffer dst(dataSize);
+    buffer src(dataSize);
+    fw.start_fusion();
+    q.submit([&](handler &cgh) {
+      accessor acc0(src, cgh, read_only);
+      accessor acc1(dst, cgh, write_only);
+      cgh.copy(acc0, acc1);
+    });
+    fw.complete_fusion();
+  }
+
+  {
+    // CHECK: Not fusing 'barrier' command group. Can only fuse device kernel command groups
+    fw.start_fusion();
+    q.submit([&](handler &cgh) { cgh.ext_oneapi_barrier(); });
+    fw.complete_fusion();
+  }
+
+  {
+    // CHECK: Not fusing 'barrier waitlist' command group. Can only fuse device kernel command groups
+    buffer dst(dataSize);
+    buffer src(dataSize);
+    std::vector event_list;
+    event_list.push_back(q.submit([&](handler &cgh) {
+      accessor acc0(src, cgh, read_only);
+      accessor acc1(dst, cgh, write_only);
+      cgh.copy(acc0, acc1);
+    }));
+    fw.start_fusion();
+    q.submit([&](handler &cgh) { cgh.ext_oneapi_barrier(event_list); });
+    fw.complete_fusion();
+  }
+
+  {
+    // CHECK: Not fusing 'fill' command group. Can only fuse device kernel command groups
+    buffer dst(dataSize);
+    fw.start_fusion();
+    q.submit([&](handler &cgh) {
+      accessor acc(dst, cgh, write_only);
+      cgh.fill(acc, Pattern);
+    });
+    fw.complete_fusion();
+  }
+
+  {
+    // CHECK: Not fusing 'copy usm' command group. Can only fuse device kernel command groups
+    fw.start_fusion();
+    q.submit([&](handler &cgh) { cgh.memcpy(dst, src, count); });
+    fw.complete_fusion();
+  }
+
+  {
+    // CHECK: Not fusing 'fill usm' command group. Can only fuse device kernel command groups
+    fw.start_fusion();
+    q.submit([&](handler &cgh) { cgh.fill(dst, Pattern, count); });
+    fw.complete_fusion();
+  }
+
+  {
+    // CHECK: Not fusing 'prefetch usm' command group. Can only fuse device kernel command groups
+    fw.start_fusion();
+    q.submit([&](handler &cgh) { cgh.prefetch(dst, count); });
+    fw.complete_fusion();
+  }
+
+  free(src, q);
+  free(dst, q);
+}