From 0bf51451e0358c256cba68a8c37c8a402775099b Mon Sep 17 00:00:00 2001 From: KornevNikita Date: Wed, 7 Feb 2024 03:53:11 -0800 Subject: [PATCH 1/2] [SYCL] Align root_group with spec Spec: https://github.com/intel/llvm/pull/12643 --- .../ext/oneapi/experimental/root_group.hpp | 76 +++++++++++++++++++ sycl/test-e2e/GroupAlgorithm/root_group.cpp | 53 ++++++++++++- 2 files changed, 128 insertions(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp b/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp index 8cbc88ccf6194..9f1abb84e30ad 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp @@ -42,6 +42,13 @@ template <> struct detail::PropertyToKind { template <> struct detail::IsCompileTimeProperty : std::true_type {}; +enum class execution_scope { + work_item, + sub_group, + work_group, + root_group, +}; + template class root_group { public: using id_type = id; @@ -78,6 +85,75 @@ template class root_group { bool leader() const { return get_local_id() == 0; }; + template + using checkScopeTy = std::enable_if_t<(Scope == execution_scope::work_item || + Scope == execution_scope::sub_group || + Scope == execution_scope::work_group), + RetTy>; + + template + std::enable_if_t<(Scope == execution_scope::work_item || + Scope == execution_scope::work_group), + id> + get_id() const { + if constexpr (Scope == execution_scope::work_item) + return it.get_global_id(); + else if constexpr (Scope == execution_scope::work_group) + return it.get_group().get_group_id(); + } + + template + std::enable_if_t> get_id() const { + return get_linear_id(); + } + + template + checkScopeTy get_linear_id() const { + if constexpr (Scope == execution_scope::work_item) { + return it.get_global_linear_id(); + } else if constexpr (Scope == execution_scope::sub_group) { + size_t WIId = it.get_global_linear_id(); + size_t SGSize = it.get_sub_group().get_local_linear_range(); + return WIId / SGSize; + } else if constexpr (Scope == execution_scope::work_group) { + return it.get_group().get_group_linear_id(); + } + } + + template + std::enable_if_t<(Scope == execution_scope::work_item || + Scope == execution_scope::work_group), + range> + get_range() const { + if constexpr (Scope == execution_scope::work_item) + return it.get_global_range(); + else if constexpr (Scope == execution_scope::work_group) + return it.get_group().get_group_range(); + } + + template + std::enable_if_t> + get_range() const { + return get_linear_range(); + } + + template + checkScopeTy get_linear_range() const { + if constexpr (Scope == execution_scope::work_item) { + range Range = it.get_global_range(); + size_t linRange = 1; + for (int i = 0; i < Dimensions; ++i) + linRange *= Range[i]; + return linRange; + } else if constexpr (Scope == execution_scope::sub_group) { + uint32_t NumWG = it.get_group().get_group_linear_range(); + uint32_t NumSGPerWG = it.get_sub_group().get_group_linear_range(); + return NumWG * NumSGPerWG; + } else if constexpr (Scope == execution_scope::work_group) { + return it.get_group().get_group_linear_range(); + } + } + private: friend root_group nd_item::ext_oneapi_get_root_group() const; diff --git a/sycl/test-e2e/GroupAlgorithm/root_group.cpp b/sycl/test-e2e/GroupAlgorithm/root_group.cpp index ba0c49fa68bf7..3a1e78b1d3ce6 100644 --- a/sycl/test-e2e/GroupAlgorithm/root_group.cpp +++ b/sycl/test-e2e/GroupAlgorithm/root_group.cpp @@ -75,7 +75,7 @@ void testRootGroupFunctions() { const auto props = sycl::ext::oneapi::experimental::properties{ sycl::ext::oneapi::experimental::use_root_sync}; - constexpr int testCount = 10; + constexpr int testCount = 22; bool *testResults = sycl::malloc_shared(testCount, q); const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize}; q.parallel_for( @@ -107,6 +107,57 @@ void testRootGroupFunctions() { sycl::sub_group>, "get_child_group(sycl::group) must return a sycl::sub_group"); } + + auto SG = it.get_sub_group(); + size_t SGSize = SG.get_local_linear_range(); + using execution_scope = + sycl::ext::oneapi::experimental::execution_scope; + + if (root.leader()) { + size_t NumSGPerWG = SG.get_group_linear_range(); + + testResults[10] = + root.template get_range()[0] == + maxWGs; + testResults[11] = + root.template get_range()[0] == + NumSGPerWG * maxWGs; + testResults[12] = + root.template get_range()[0] == + maxWGs * WorkGroupSize; + + testResults[13] = + root.template get_linear_range() == + maxWGs; + testResults[14] = + root.template get_linear_range() == + NumSGPerWG * maxWGs; + testResults[15] = + root.template get_linear_range() == + maxWGs * WorkGroupSize; + } + + if (root.get_local_id() == 3) { + testResults[16] = + root.template get_id() == + it.get_global_linear_id() / WorkGroupSize; + testResults[17] = + root.template get_id() == + it.get_global_linear_id() / SGSize; + testResults[18] = + root.template get_id() == + it.get_global_linear_id(); + + testResults[19] = + root.template get_linear_id() == + it.get_global_linear_id() / WorkGroupSize; + testResults[20] = + root.template get_linear_id() == + it.get_global_linear_id() / SGSize; + testResults[21] = + root.template get_linear_id() == + it.get_global_linear_id(); + } }); q.wait(); From 0b2de7458bd7fa236bf653d59b855ccd28edbf1f Mon Sep 17 00:00:00 2001 From: KornevNikita Date: Thu, 8 Feb 2024 03:34:21 -0800 Subject: [PATCH 2/2] Move checkScopeTy to private --- .../sycl/ext/oneapi/experimental/root_group.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp b/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp index 9f1abb84e30ad..36f0d24a0525a 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp @@ -50,6 +50,12 @@ enum class execution_scope { }; template class root_group { + template + using checkScopeTy = std::enable_if_t<(Scope == execution_scope::work_item || + Scope == execution_scope::sub_group || + Scope == execution_scope::work_group), + RetTy>; + public: using id_type = id; using range_type = range; @@ -85,12 +91,6 @@ template class root_group { bool leader() const { return get_local_id() == 0; }; - template - using checkScopeTy = std::enable_if_t<(Scope == execution_scope::work_item || - Scope == execution_scope::sub_group || - Scope == execution_scope::work_group), - RetTy>; - template std::enable_if_t<(Scope == execution_scope::work_item || Scope == execution_scope::work_group),