Skip to content

Commit

Permalink
[SYCL][Fusion] Restrict types of fusable command groups (#12556)
Browse files Browse the repository at this point in the history
Only allow command groups of `Kernel` type. Do not add other kind of
command groups to the fusable graph when found, showing a descriptive
warning.

---------

Signed-off-by: Victor Perez <victor.perez@codeplay.com>
  • Loading branch information
victor-eds authored Feb 6, 2024
1 parent 18d6471 commit 2f253a9
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 58 deletions.
2 changes: 1 addition & 1 deletion sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
for (auto &RawCmd : InputKernels) {
auto *KernelCmd = static_cast<ExecCGCommand *>(RawCmd);
auto &CG = KernelCmd->getCG();
assert(CG.getType() == CG::Kernel);
assert(KernelCmd->isFusable());
auto *KernelCG = static_cast<CGExecKernel *>(&CG);

auto KernelName = KernelCG->MKernelName;
Expand Down
32 changes: 29 additions & 3 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,12 @@ bool Command::isHostTask() const {
CG::CGTYPE::CodeplayHostTask);
}

bool Command::isFusable() const {
return (MType == CommandType::RUN_CG) &&
((static_cast<const ExecCGCommand *>(this))->getCG().getType() ==
CG::CGTYPE::Kernel);
}

static void flushCrossQueueDeps(const std::vector<EventImplPtr> &EventImpls,
const QueueImplPtr &Queue) {
for (auto &EventImpl : EventImpls) {
Expand Down Expand Up @@ -1825,7 +1831,7 @@ void UpdateHostRequirementCommand::emitInstrumentationData() {
#endif
}

static std::string cgTypeToString(detail::CG::CGTYPE Type) {
static std::string_view cgTypeToString(detail::CG::CGTYPE Type) {
switch (Type) {
case detail::CG::Kernel:
return "Kernel";
Expand All @@ -1845,6 +1851,10 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
case detail::CG::CopyPtrToAcc:
return "copy ptr to acc";
break;
case detail::CG::Barrier:
return "barrier";
case detail::CG::BarrierWaitlist:
return "barrier waitlist";
case detail::CG::CopyUSM:
return "copy usm";
break;
Expand All @@ -1863,6 +1873,8 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
case detail::CG::Fill2DUSM:
return "fill 2d usm";
break;
case detail::CG::AdviseUSM:
return "advise usm";
case detail::CG::Memset2DUSM:
return "memset 2d usm";
break;
Expand All @@ -1872,6 +1884,16 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
case detail::CG::CopyFromDeviceGlobal:
return "copy from device_global";
break;
case detail::CG::ReadWriteHostPipe:
return "read_write host pipe";
case detail::CG::ExecCommandBuffer:
return "exec command buffer";
case detail::CG::CopyImage:
return "copy image";
case detail::CG::SemaphoreWait:
return "semaphore wait";
case detail::CG::SemaphoreSignal:
return "semaphore signal";
default:
return "unknown";
break;
Expand Down Expand Up @@ -2102,7 +2124,7 @@ void ExecCGCommand::emitInstrumentationData() {
KernelCG->getKernelName(), MAddress, FromSource);
} break;
default:
KernelName = cgTypeToString(MCommandGroup->getType());
KernelName = getTypeString();
break;
}

Expand Down Expand Up @@ -2150,7 +2172,7 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
break;
}
default:
Stream << "CG type: " << cgTypeToString(MCommandGroup->getType()) << "\\n";
Stream << "CG type: " << getTypeString() << "\\n";
break;
}

Expand All @@ -2165,6 +2187,10 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
}
}

std::string_view ExecCGCommand::getTypeString() const {
return cgTypeToString(MCommandGroup->getType());
}

// SYCL has a parallel_for_work_group variant where the only NDRange
// characteristics set by a user is the number of work groups. This does not
// map to the OpenCL clEnqueueNDRangeAPI, which requires global work size to
Expand Down
3 changes: 3 additions & 0 deletions sycl/source/detail/scheduler/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ class Command {

bool isHostTask() const;

bool isFusable() const;

protected:
QueueImplPtr MQueue;
EventImplPtr MEvent;
Expand Down Expand Up @@ -648,6 +650,7 @@ class ExecCGCommand : public Command {

void printDot(std::ostream &Stream) const final;
void emitInstrumentationData() final;
std::string_view getTypeString() const;

detail::CG &getCG() const { return *MCommandGroup; }

Expand Down
118 changes: 64 additions & 54 deletions sycl/source/detail/scheduler/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "detail/config.hpp"
#include <detail/context_impl.hpp>
#include <detail/event_impl.hpp>
#include <sstream>
#include <sycl/feature_test.hpp>
#if SYCL_EXT_CODEPLAY_KERNEL_FUSION
#include <detail/jit_compiler.hpp>
Expand Down Expand Up @@ -949,66 +950,75 @@ Scheduler::GraphBuildResult Scheduler::GraphBuilder::addCG(
if (!NewCmd)
throw runtime_error("Out of host memory", PI_ERROR_OUT_OF_HOST_MEMORY);

// Host tasks cannot participate in fusion. They take the regular route. If
// they create any requirement or event dependency on any of the kernels in
// the fusion list, this will lead to cancellation of the fusion in the
// GraphProcessor.
// Only device kernel command groups can participate in fusion. Otherwise,
// command groups take the regular route. If they create any requirement or
// event dependency on any of the kernels in the fusion list, this will lead
// to cancellation of the fusion in the GraphProcessor.
auto QUniqueID = std::hash<sycl::detail::queue_impl *>()(Queue.get());
if (isInFusionMode(QUniqueID) && !NewCmd->isHostTask()) {
auto *FusionCmd = findFusionList(QUniqueID)->second.get();

bool dependsOnFusion = false;
for (auto Ev = Events.begin(); Ev != Events.end();) {
auto *EvDepCmd = static_cast<Command *>((*Ev)->getCommand());
if (!EvDepCmd) {
continue;
}
// Handle event dependencies on any commands part of another active
// fusion.
if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) {
printFusionWarning("Aborting fusion because of event dependency from a "
"different fusion");
cancelFusion(EvDepCmd->getQueue(), ToEnqueue);
}
// Check if this command depends on the placeholder command for the fusion
// itself participates in.
if (EvDepCmd == FusionCmd) {
Ev = Events.erase(Ev);
dependsOnFusion = true;
} else {
++Ev;
if (isInFusionMode(QUniqueID)) {
if (NewCmd->isFusable()) {
auto *FusionCmd = findFusionList(QUniqueID)->second.get();

bool dependsOnFusion = false;
for (auto Ev = Events.begin(); Ev != Events.end();) {
auto *EvDepCmd = static_cast<Command *>((*Ev)->getCommand());
if (!EvDepCmd) {
continue;
}
// Handle event dependencies on any commands part of another active
// fusion.
if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) {
printFusionWarning(
"Aborting fusion because of event dependency from a "
"different fusion");
cancelFusion(EvDepCmd->getQueue(), ToEnqueue);
}
// Check if this command depends on the placeholder command for the
// fusion itself participates in.
if (EvDepCmd == FusionCmd) {
Ev = Events.erase(Ev);
dependsOnFusion = true;
} else {
++Ev;
}
}
}

// If this command has an explicit event dependency on the placeholder
// command for this fusion (because it used depends_on on the event returned
// by submitting another kernel to this fusion earlier), add a dependency on
// all the commands in the fusion list so far.
if (dependsOnFusion) {
for (auto *Cmd : FusionCmd->getFusionList()) {
Events.push_back(Cmd->getEvent());
// If this command has an explicit event dependency on the placeholder
// command for this fusion (because it used depends_on on the event
// returned by submitting another kernel to this fusion earlier), add a
// dependency on all the commands in the fusion list so far.
if (dependsOnFusion) {
for (auto *Cmd : FusionCmd->getFusionList()) {
Events.push_back(Cmd->getEvent());
}
}
}

// Add the kernel to the graph, but delay the enqueue of any auxiliary
// commands (e.g., allocations) resulting from that process by adding them
// to the list of auxiliary commands of the fusion command.
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
FusionCmd->auxiliaryCommands());

// Set the fusion command, so we recognize when another command depends on a
// kernel in the fusion list.
FusionCmd->addToFusionList(NewCmd.get());
NewCmd->MFusionCmd = FusionCmd;
std::vector<Command *> ToCleanUp;
// Add an event dependency from the fusion placeholder command to the new
// kernel.
auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp);
if (ConnectionCmd) {
FusionCmd->auxiliaryCommands().push_back(ConnectionCmd);
// Add the kernel to the graph, but delay the enqueue of any auxiliary
// commands (e.g., allocations) resulting from that process by adding them
// to the list of auxiliary commands of the fusion command.
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
isInteropHostTask(NewCmd.get()), Reqs, Events,
Queue, FusionCmd->auxiliaryCommands());

// Set the fusion command, so we recognize when another command depends on
// a kernel in the fusion list.
FusionCmd->addToFusionList(NewCmd.get());
NewCmd->MFusionCmd = FusionCmd;
std::vector<Command *> ToCleanUp;
// Add an event dependency from the fusion placeholder command to the new
// kernel.
auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp);
if (ConnectionCmd) {
FusionCmd->auxiliaryCommands().push_back(ConnectionCmd);
}
return {NewCmd.release(), FusionCmd->getEvent(), false};
} else {
std::string s;
std::stringstream ss(s);
ss << "Not fusing '" << NewCmd->getTypeString()
<< "' command group. Can only fuse device kernel command groups.";
printFusionWarning(ss.str());
}
return {NewCmd.release(), FusionCmd->getEvent(), false};
}
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
Expand Down
109 changes: 109 additions & 0 deletions sycl/test-e2e/KernelFusion/non-kernel-cg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// RUN: %{build} -fsycl-embed-ir -o %t.out
// RUN: env SYCL_RT_WARNING_LEVEL=2 %{run} %t.out 2>&1 | FileCheck %s

// Test non-kernel device command groups are not fused

#include <sycl/sycl.hpp>

using namespace sycl;

int main() {
constexpr size_t dataSize = 512;
constexpr float Pattern{10};

queue q{ext::codeplay::experimental::property::queue::enable_fusion{}};
ext::codeplay::experimental::fusion_wrapper fw(q);

constexpr size_t count = 64;
auto *dst = malloc_device<float>(count, q);
auto *src = malloc_device<float>(count, q);

{
// CHECK: Not fusing 'copy acc to ptr' command group. Can only fuse device kernel command groups.
buffer<float> src(dataSize);
std::shared_ptr<float> dst(new float[dataSize]);
fw.start_fusion();
q.submit([&](handler &cgh) {
accessor acc(src, cgh, read_only);
cgh.copy(acc, dst);
});
fw.complete_fusion();
}

{
// CHECK: Not fusing 'copy ptr to acc' command group. Can only fuse device kernel command groups.
buffer<float> dst(dataSize);
std::shared_ptr<float> src(new float[dataSize]);
fw.start_fusion();
q.submit([&](handler &cgh) {
accessor acc(dst, cgh, write_only);
cgh.copy(src, acc);
});
fw.complete_fusion();
}

{
// CHECK: Not fusing 'copy acc to acc' command group. Can only fuse device kernel command groups.
buffer<float> dst(dataSize);
buffer<float> src(dataSize);
fw.start_fusion();
q.submit([&](handler &cgh) {
accessor acc0(src, cgh, read_only);
accessor acc1(dst, cgh, write_only);
cgh.copy(acc0, acc1);
});
fw.complete_fusion();
}

{
// CHECK: Not fusing 'barrier' command group. Can only fuse device kernel command groups.
fw.start_fusion();
q.submit([&](handler &cgh) { cgh.ext_oneapi_barrier(); });
fw.complete_fusion();
}

{
// CHECK: Not fusing 'barrier waitlist' command group. Can only fuse device kernel command groups.
buffer<float> dst(dataSize);
buffer<float> src(dataSize);
std::vector<event> event_list;
event_list.push_back(q.submit([&](handler &cgh) {
accessor acc0(src, cgh, read_only);
accessor acc1(dst, cgh, write_only);
cgh.copy(acc0, acc1);
}));
fw.start_fusion();
q.submit([&](handler &cgh) { cgh.ext_oneapi_barrier(event_list); });
fw.complete_fusion();
}

{
// CHECK: Not fusing 'fill' command group. Can only fuse device kernel command groups.
buffer<float> dst(dataSize);
fw.start_fusion();
q.submit([&](handler &cgh) {
accessor acc(dst, cgh, write_only);
cgh.fill(acc, Pattern);
});
fw.complete_fusion();
}

{
// CHECK: Not fusing 'copy usm' command group. Can only fuse device kernel command groups.
fw.start_fusion();
q.submit([&](handler &cgh) { cgh.memcpy(dst, src, count); });
fw.complete_fusion();
}

{
// CHECK: Not fusing 'fill usm' command group. Can only fuse device kernel command groups.
fw.start_fusion();
q.submit([&](handler &cgh) {
cgh.memset(dst, static_cast<int>(Pattern), count);
});
fw.complete_fusion();
}

free(src, q);
free(dst, q);
}

0 comments on commit 2f253a9

Please sign in to comment.