Skip to content

Commit

Permalink
[SYCL] Mark parallel_for_work_item even when called indirectly (#12805)
Browse files Browse the repository at this point in the history
Previously, we would mark parallel_for_work_item FunctionDecl only
when called directly from a parallel_for_work_group region.  This
change marks it when called even indirectly.
  • Loading branch information
premanandrao authored Mar 4, 2024
1 parent ea400f7 commit 19bb017
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 6 deletions.
44 changes: 38 additions & 6 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,25 @@ class DeviceFunctionTracker {
}
};

/// This function checks whether given DeclContext contains a topmost
/// namespace with name "sycl".
static bool isDeclaredInSYCLNamespace(const Decl *D) {
const DeclContext *DC = D->getDeclContext()->getEnclosingNamespaceContext();
const auto *ND = dyn_cast<NamespaceDecl>(DC);
// If this is not a namespace, then we are done.
if (!ND)
return false;

// While it is a namespace, find its parent scope.
while (const DeclContext *Parent = ND->getParent()) {
if (!isa<NamespaceDecl>(Parent))
break;
ND = cast<NamespaceDecl>(Parent);
}

return ND && ND->getName() == "sycl";
}

// This type does the heavy lifting for the management of device functions,
// recursive function detection, and attribute collection for a single
// kernel/external function. It walks the callgraph to find all functions that
Expand Down Expand Up @@ -770,6 +789,20 @@ class SingleDeviceFunctionTracker {
Parent.SemaRef.addFDToReachableFromSyclDevice(CurrentDecl,
CallStack.back());

// If this is a parallel_for_work_item that is declared in the
// sycl namespace, mark it with the WorkItem scope attribute.
// Note: Here, we assume that this is called from within a
// parallel_for_work_group; it is undefined to call it otherwise.
// We deliberately do not diagnose a violation.
if (CurrentDecl->getIdentifier() &&
CurrentDecl->getIdentifier()->getName() == "parallel_for_work_item" &&
isDeclaredInSYCLNamespace(CurrentDecl) &&
!CurrentDecl->hasAttr<SYCLScopeAttr>()) {
CurrentDecl->addAttr(
SYCLScopeAttr::CreateImplicit(Parent.SemaRef.getASTContext(),
SYCLScopeAttr::Level::WorkItem));
}

// We previously thought we could skip this function if we'd seen it before,
// but if we haven't seen it before in this call graph, we can end up
// missing a recursive call. SO, we have to revisit call-graphs we've
Expand Down Expand Up @@ -919,14 +952,13 @@ class MarkWIScopeFnVisitor : public RecursiveASTVisitor<MarkWIScopeFnVisitor> {
// not a member of sycl::group - continue search
return true;
auto Name = Callee->getName();
if (((Name != "parallel_for_work_item") && (Name != "wait_for")) ||
if (Name != "wait_for" ||
Callee->hasAttr<SYCLScopeAttr>())
return true;
// it is a call to sycl::group::parallel_for_work_item/wait_for -
// mark the callee
// it is a call to sycl::group::wait_for - mark the callee
Callee->addAttr(
SYCLScopeAttr::CreateImplicit(Ctx, SYCLScopeAttr::Level::WorkItem));
// continue search as there can be other PFWI or wait_for calls
// continue search as there can be other wait_for calls
return true;
}

Expand Down Expand Up @@ -2968,7 +3000,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {

assert(CallOperator && "non callable object is passed as kernel obj");
// Mark the function that it "works" in a work group scope:
// NOTE: In case of parallel_for_work_item the marker call itself is
// NOTE: In case of wait_for the marker call itself is
// marked with work item scope attribute, here the '()' operator of the
// object passed as parameter is marked. This is an optimization -
// there are a lot of locals created at parallel_for_work_group
Expand All @@ -2979,7 +3011,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
if (!CallOperator->hasAttr<SYCLScopeAttr>()) {
CallOperator->addAttr(SYCLScopeAttr::CreateImplicit(
SemaRef.getASTContext(), SYCLScopeAttr::Level::WorkGroup));
// Search and mark parallel_for_work_item calls:
// Search and mark wait_for calls:
MarkWIScopeFnVisitor MarkWIScope(SemaRef.getASTContext());
MarkWIScope.TraverseDecl(CallOperator);
// Now mark local variables declared in the PFWG lambda with work group
Expand Down
4 changes: 4 additions & 0 deletions clang/test/CodeGenSYCL/Inputs/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ template <int dimensions = 1>
class __SYCL_TYPE(group) group {
public:
group() = default; // fake constructor
// Dummy parallel_for_work_item function to mimic calls from
// parallel_for_work_group.
void parallel_for_work_item() {
}
};

namespace access {
Expand Down
21 changes: 21 additions & 0 deletions clang/test/CodeGenSYCL/sycl-pf-work-item.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown -internal-isystem %S/Inputs -emit-llvm %s -o - | FileCheck %s
// This test checks if the parallel_for_work_item called indirecly from
// parallel_for_work_group gets the work_item_scope marker on it.
#include <sycl.hpp>

void foo(sycl::group<1> work_group) {
work_group.parallel_for_work_item();
}

int main(int argc, char **argv) {
sycl::queue q;
q.submit([&](sycl::handler &cgh) {
cgh.parallel_for_work_group(
sycl::range<1>{1}, sycl::range<1>{1024}, ([=](sycl::group<1> wGroup) {
foo(wGroup);
}));
});
return 0;
}

// CHECK: define {{.*}} void @{{.*}}sycl{{.*}}group{{.*}}parallel_for_work_item{{.*}}(ptr addrspace(4) noundef align 1 dereferenceable_or_null(1) %this) {{.*}}!work_item_scope {{.*}}!parallel_for_work_item

0 comments on commit 19bb017

Please sign in to comment.