Skip to content

Commit

Permalink
[SYCL] Align root_group with spec
Browse files Browse the repository at this point in the history
  • Loading branch information
KornevNikita committed Feb 7, 2024
1 parent 87b55fd commit 0bf5145
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 1 deletion.
76 changes: 76 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/root_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ template <> struct detail::PropertyToKind<use_root_sync_key> {
template <>
struct detail::IsCompileTimeProperty<use_root_sync_key> : std::true_type {};

enum class execution_scope {
work_item,
sub_group,
work_group,
root_group,
};

template <int Dimensions> class root_group {
public:
using id_type = id<Dimensions>;
Expand Down Expand Up @@ -78,6 +85,75 @@ template <int Dimensions> class root_group {

bool leader() const { return get_local_id() == 0; };

template <execution_scope Scope, typename RetTy>
using checkScopeTy = std::enable_if_t<(Scope == execution_scope::work_item ||
Scope == execution_scope::sub_group ||
Scope == execution_scope::work_group),
RetTy>;

template <execution_scope Scope>
std::enable_if_t<(Scope == execution_scope::work_item ||
Scope == execution_scope::work_group),
id<Dimensions>>
get_id() const {
if constexpr (Scope == execution_scope::work_item)
return it.get_global_id();
else if constexpr (Scope == execution_scope::work_group)
return it.get_group().get_group_id();
}

template <execution_scope Scope>
std::enable_if_t<Scope == execution_scope::sub_group, id<1>> get_id() const {
return get_linear_id<execution_scope::sub_group>();
}

template <execution_scope Scope>
checkScopeTy<Scope, size_t> get_linear_id() const {
if constexpr (Scope == execution_scope::work_item) {
return it.get_global_linear_id();
} else if constexpr (Scope == execution_scope::sub_group) {
size_t WIId = it.get_global_linear_id();
size_t SGSize = it.get_sub_group().get_local_linear_range();
return WIId / SGSize;
} else if constexpr (Scope == execution_scope::work_group) {
return it.get_group().get_group_linear_id();
}
}

template <execution_scope Scope>
std::enable_if_t<(Scope == execution_scope::work_item ||
Scope == execution_scope::work_group),
range<Dimensions>>
get_range() const {
if constexpr (Scope == execution_scope::work_item)
return it.get_global_range();
else if constexpr (Scope == execution_scope::work_group)
return it.get_group().get_group_range();
}

template <execution_scope Scope>
std::enable_if_t<Scope == execution_scope::sub_group, range<1>>
get_range() const {
return get_linear_range<execution_scope::sub_group>();
}

template <execution_scope Scope>
checkScopeTy<Scope, size_t> get_linear_range() const {
if constexpr (Scope == execution_scope::work_item) {
range<Dimensions> Range = it.get_global_range();
size_t linRange = 1;
for (int i = 0; i < Dimensions; ++i)
linRange *= Range[i];
return linRange;
} else if constexpr (Scope == execution_scope::sub_group) {
uint32_t NumWG = it.get_group().get_group_linear_range();
uint32_t NumSGPerWG = it.get_sub_group().get_group_linear_range();
return NumWG * NumSGPerWG;
} else if constexpr (Scope == execution_scope::work_group) {
return it.get_group().get_group_linear_range();
}
}

private:
friend root_group<Dimensions>
nd_item<Dimensions>::ext_oneapi_get_root_group() const;
Expand Down
53 changes: 52 additions & 1 deletion sycl/test-e2e/GroupAlgorithm/root_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void testRootGroupFunctions() {
const auto props = sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::use_root_sync};

constexpr int testCount = 10;
constexpr int testCount = 22;
bool *testResults = sycl::malloc_shared<bool>(testCount, q);
const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize};
q.parallel_for<class RootGroupFunctionsKernel>(
Expand Down Expand Up @@ -107,6 +107,57 @@ void testRootGroupFunctions() {
sycl::sub_group>,
"get_child_group(sycl::group) must return a sycl::sub_group");
}

auto SG = it.get_sub_group();
size_t SGSize = SG.get_local_linear_range();
using execution_scope =
sycl::ext::oneapi::experimental::execution_scope;

if (root.leader()) {
size_t NumSGPerWG = SG.get_group_linear_range();

testResults[10] =
root.template get_range<execution_scope::work_group>()[0] ==
maxWGs;
testResults[11] =
root.template get_range<execution_scope::sub_group>()[0] ==
NumSGPerWG * maxWGs;
testResults[12] =
root.template get_range<execution_scope::work_item>()[0] ==
maxWGs * WorkGroupSize;

testResults[13] =
root.template get_linear_range<execution_scope::work_group>() ==
maxWGs;
testResults[14] =
root.template get_linear_range<execution_scope::sub_group>() ==
NumSGPerWG * maxWGs;
testResults[15] =
root.template get_linear_range<execution_scope::work_item>() ==
maxWGs * WorkGroupSize;
}

if (root.get_local_id() == 3) {
testResults[16] =
root.template get_id<execution_scope::work_group>() ==
it.get_global_linear_id() / WorkGroupSize;
testResults[17] =
root.template get_id<execution_scope::sub_group>() ==
it.get_global_linear_id() / SGSize;
testResults[18] =
root.template get_id<execution_scope::work_item>() ==
it.get_global_linear_id();

testResults[19] =
root.template get_linear_id<execution_scope::work_group>() ==
it.get_global_linear_id() / WorkGroupSize;
testResults[20] =
root.template get_linear_id<execution_scope::sub_group>() ==
it.get_global_linear_id() / SGSize;
testResults[21] =
root.template get_linear_id<execution_scope::work_item>() ==
it.get_global_linear_id();
}
});
q.wait();

Expand Down

0 comments on commit 0bf5145

Please sign in to comment.