Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Align root_group with spec #12653

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/root_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,20 @@ template <> struct detail::PropertyToKind<use_root_sync_key> {
template <>
struct detail::IsCompileTimeProperty<use_root_sync_key> : std::true_type {};

enum class execution_scope {
work_item,
sub_group,
work_group,
root_group,
KornevNikita marked this conversation as resolved.
Show resolved Hide resolved
};

template <int Dimensions> class root_group {
template <execution_scope Scope, typename RetTy>
using checkScopeTy = std::enable_if_t<(Scope == execution_scope::work_item ||
Scope == execution_scope::sub_group ||
Scope == execution_scope::work_group),
RetTy>;

public:
using id_type = id<Dimensions>;
using range_type = range<Dimensions>;
Expand Down Expand Up @@ -78,6 +91,69 @@ template <int Dimensions> class root_group {

bool leader() const { return get_local_id() == 0; };

template <execution_scope Scope>
std::enable_if_t<(Scope == execution_scope::work_item ||
Scope == execution_scope::work_group),
id<Dimensions>>
get_id() const {
if constexpr (Scope == execution_scope::work_item)
return it.get_global_id();
else if constexpr (Scope == execution_scope::work_group)
return it.get_group().get_group_id();
}

template <execution_scope Scope>
std::enable_if_t<Scope == execution_scope::sub_group, id<1>> get_id() const {
return get_linear_id<execution_scope::sub_group>();
}

template <execution_scope Scope>
checkScopeTy<Scope, size_t> get_linear_id() const {
if constexpr (Scope == execution_scope::work_item) {
return it.get_global_linear_id();
} else if constexpr (Scope == execution_scope::sub_group) {
size_t WIId = it.get_global_linear_id();
size_t SGSize = it.get_sub_group().get_local_linear_range();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be get_max_local_range()[0]. @Pennycook , do I remember correctly what happens when "last" sub-group is partially masked out?

If so, we'd need a test for this as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, will update.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure, but can you do this without a division?

it.get_group().get_group_linear_id() * it.get_sub_group().get_group_linear_range() + it.get_sub_group().get_group_linear_id()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just occurred to me that this would give a different linearization order, and that the linearization order implied by these functions might not be well-defined yet... Different people might expect them to do different things.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, should we proceed with the suggestion above or it requires some clarification to the spec first?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need some clarification. We should talk to some other folks and figure out what the right thing to do here is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Pennycook in general, do we even want get_linear_id to accept execution_scope::sub_group?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue affects get_id for sub-groups as well.

Because this new API is intended to be a general querying mechanism for all groups, we need to think this through carefully. I'm reluctant to start introducing corner-cases until we understand what's possible. If this API doesn't support sub-groups, it might not support other future group types.

return WIId / SGSize;
} else if constexpr (Scope == execution_scope::work_group) {
return it.get_group().get_group_linear_id();
}
}

template <execution_scope Scope>
std::enable_if_t<(Scope == execution_scope::work_item ||
Scope == execution_scope::work_group),
range<Dimensions>>
get_range() const {
if constexpr (Scope == execution_scope::work_item)
return it.get_global_range();
else if constexpr (Scope == execution_scope::work_group)
return it.get_group().get_group_range();
}

template <execution_scope Scope>
std::enable_if_t<Scope == execution_scope::sub_group, range<1>>
get_range() const {
return get_linear_range<execution_scope::sub_group>();
}

template <execution_scope Scope>
checkScopeTy<Scope, size_t> get_linear_range() const {
if constexpr (Scope == execution_scope::work_item) {
range<Dimensions> Range = it.get_global_range();
size_t linRange = 1;
for (int i = 0; i < Dimensions; ++i)
linRange *= Range[i];
return linRange;
} else if constexpr (Scope == execution_scope::sub_group) {
uint32_t NumWG = it.get_group().get_group_linear_range();
uint32_t NumSGPerWG = it.get_sub_group().get_group_linear_range();
return NumWG * NumSGPerWG;
} else if constexpr (Scope == execution_scope::work_group) {
return it.get_group().get_group_linear_range();
}
}

private:
friend root_group<Dimensions>
nd_item<Dimensions>::ext_oneapi_get_root_group() const;
Expand Down
53 changes: 52 additions & 1 deletion sycl/test-e2e/GroupAlgorithm/root_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void testRootGroupFunctions() {
const auto props = sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::use_root_sync};

constexpr int testCount = 10;
constexpr int testCount = 22;
sycl::buffer<bool> testResultsBuf{sycl::range{testCount}};
const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize};
q.submit([&](sycl::handler &h) {
Expand Down Expand Up @@ -111,6 +111,57 @@ void testRootGroupFunctions() {
sycl::sub_group>,
"get_child_group(sycl::group) must return a sycl::sub_group");
}

auto SG = it.get_sub_group();
size_t SGSize = SG.get_local_linear_range();
using execution_scope =
sycl::ext::oneapi::experimental::execution_scope;

if (root.leader()) {
size_t NumSGPerWG = SG.get_group_linear_range();

testResults[10] =
root.template get_range<execution_scope::work_group>()[0] ==
maxWGs;
testResults[11] =
root.template get_range<execution_scope::sub_group>()[0] ==
NumSGPerWG * maxWGs;
testResults[12] =
root.template get_range<execution_scope::work_item>()[0] ==
maxWGs * WorkGroupSize;

testResults[13] =
root.template get_linear_range<execution_scope::work_group>() ==
maxWGs;
testResults[14] =
root.template get_linear_range<execution_scope::sub_group>() ==
NumSGPerWG * maxWGs;
testResults[15] =
root.template get_linear_range<execution_scope::work_item>() ==
maxWGs * WorkGroupSize;
}

if (root.get_local_id() == 3) {
testResults[16] =
root.template get_id<execution_scope::work_group>() ==
it.get_global_linear_id() / WorkGroupSize;
testResults[17] =
root.template get_id<execution_scope::sub_group>() ==
it.get_global_linear_id() / SGSize;
testResults[18] =
root.template get_id<execution_scope::work_item>() ==
it.get_global_linear_id();

testResults[19] =
root.template get_linear_id<execution_scope::work_group>() ==
it.get_global_linear_id() / WorkGroupSize;
testResults[20] =
root.template get_linear_id<execution_scope::sub_group>() ==
it.get_global_linear_id() / SGSize;
testResults[21] =
root.template get_linear_id<execution_scope::work_item>() ==
it.get_global_linear_id();
}
});
});
sycl::host_accessor testResults{testResultsBuf};
Expand Down
Loading