Skip to content

Commit

Permalink
Merge pull request #914 from steffenlarsen/steffen/template_on_group_…
Browse files Browse the repository at this point in the history
…type

Template non-uniform group test cases on the group type
  • Loading branch information
steffenlarsen authored Jul 19, 2024
2 parents 874b5c2 + 31a36c7 commit b4e3600
Show file tree
Hide file tree
Showing 22 changed files with 1,078 additions and 1,216 deletions.
19 changes: 4 additions & 15 deletions tests/extension/oneapi_non_uniform_groups/group_barrier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,10 @@ namespace non_uniform_groups::tests {
template <int D>
class test_fence;

TEST_CASE("Non-uniform-group barriers",
"[oneapi_non_uniform_groups][group_func]") {
auto queue = once_per_unit::get_queue();

non_uniform_group_barrier<oneapi_ext::ballot_group<sycl::sub_group>>(queue);
non_uniform_group_barrier<oneapi_ext::fixed_size_group<1, sycl::sub_group>>(
queue);
non_uniform_group_barrier<oneapi_ext::fixed_size_group<2, sycl::sub_group>>(
queue);
non_uniform_group_barrier<oneapi_ext::fixed_size_group<4, sycl::sub_group>>(
queue);
non_uniform_group_barrier<oneapi_ext::fixed_size_group<8, sycl::sub_group>>(
queue);
non_uniform_group_barrier<oneapi_ext::tangle_group<sycl::sub_group>>(queue);
non_uniform_group_barrier<oneapi_ext::opportunistic_group>(queue);
TEMPLATE_LIST_TEST_CASE("Non-uniform-group barriers",
"[oneapi_non_uniform_groups][group_func]",
GroupPackTypes) {
for_all_combinations<non_uniform_group_barrier_test>(TestType{});
}

} // namespace non_uniform_groups::tests
361 changes: 183 additions & 178 deletions tests/extension/oneapi_non_uniform_groups/group_barrier.h

Large diffs are not rendered by default.

20 changes: 4 additions & 16 deletions tests/extension/oneapi_non_uniform_groups/group_broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,13 @@

namespace non_uniform_groups::tests {

using BroadcastTypes = CustomTypes;

TEMPLATE_LIST_TEST_CASE("Non-uniform group broadcast and select",
"[oneapi_non_uniform_groups][group_func][type_list]",
BroadcastTypes) {
GroupPackTypes) {
auto queue = once_per_unit::get_queue();
broadcast_non_uniform_group<oneapi_ext::ballot_group<sycl::sub_group>,
TestType>(queue);
broadcast_non_uniform_group<oneapi_ext::fixed_size_group<1, sycl::sub_group>,
TestType>(queue);
broadcast_non_uniform_group<oneapi_ext::fixed_size_group<2, sycl::sub_group>,
TestType>(queue);
broadcast_non_uniform_group<oneapi_ext::fixed_size_group<4, sycl::sub_group>,
TestType>(queue);
broadcast_non_uniform_group<oneapi_ext::fixed_size_group<8, sycl::sub_group>,
TestType>(queue);
broadcast_non_uniform_group<oneapi_ext::tangle_group<sycl::sub_group>,
TestType>(queue);
broadcast_non_uniform_group<oneapi_ext::opportunistic_group, TestType>(queue);

for_all_combinations<broadcast_non_uniform_group_test>(
TestType{}, CustomTypePack{}, queue);
}

} // namespace non_uniform_groups::tests
329 changes: 168 additions & 161 deletions tests/extension/oneapi_non_uniform_groups/group_broadcast.h

Large diffs are not rendered by default.

22 changes: 6 additions & 16 deletions tests/extension/oneapi_non_uniform_groups/group_broadcast_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,14 @@

namespace non_uniform_groups::tests {

TEST_CASE("Non-uniform group broadcast and select",
"[oneapi_non_uniform_groups][group_func][fp16]") {
TEMPLATE_LIST_TEST_CASE("Non-uniform group broadcast and select",
"[oneapi_non_uniform_groups][group_func][fp16]",
GroupPackTypes) {
auto queue = once_per_unit::get_queue();

if (queue.get_device().has(sycl::aspect::fp16)) {
broadcast_non_uniform_group<oneapi_ext::ballot_group<sycl::sub_group>,
sycl::half>(queue);
broadcast_non_uniform_group<
oneapi_ext::fixed_size_group<1, sycl::sub_group>, sycl::half>(queue);
broadcast_non_uniform_group<
oneapi_ext::fixed_size_group<2, sycl::sub_group>, sycl::half>(queue);
broadcast_non_uniform_group<
oneapi_ext::fixed_size_group<4, sycl::sub_group>, sycl::half>(queue);
broadcast_non_uniform_group<
oneapi_ext::fixed_size_group<8, sycl::sub_group>, sycl::half>(queue);
broadcast_non_uniform_group<oneapi_ext::tangle_group<sycl::sub_group>,
sycl::half>(queue);
broadcast_non_uniform_group<oneapi_ext::opportunistic_group, sycl::half>(
queue);
for_all_combinations<broadcast_non_uniform_group_test>(
TestType{}, unnamed_type_pack<sycl::half>{}, queue);
} else {
WARN("Device does not support half precision floating point operations.");
}
Expand Down
19 changes: 5 additions & 14 deletions tests/extension/oneapi_non_uniform_groups/group_broadcast_fp64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,13 @@

namespace non_uniform_groups::tests {

TEST_CASE("Non-uniform group broadcast and select", "[group_func][fp64]") {
TEMPLATE_LIST_TEST_CASE("Non-uniform group broadcast and select",
"[group_func][fp64]", GroupPackTypes) {
auto queue = once_per_unit::get_queue();

if (queue.get_device().has(sycl::aspect::fp64)) {
broadcast_non_uniform_group<oneapi_ext::ballot_group<sycl::sub_group>,
double>(queue);
broadcast_non_uniform_group<
oneapi_ext::fixed_size_group<1, sycl::sub_group>, double>(queue);
broadcast_non_uniform_group<
oneapi_ext::fixed_size_group<2, sycl::sub_group>, double>(queue);
broadcast_non_uniform_group<
oneapi_ext::fixed_size_group<4, sycl::sub_group>, double>(queue);
broadcast_non_uniform_group<
oneapi_ext::fixed_size_group<8, sycl::sub_group>, double>(queue);
broadcast_non_uniform_group<oneapi_ext::tangle_group<sycl::sub_group>,
double>(queue);
broadcast_non_uniform_group<oneapi_ext::opportunistic_group, double>(queue);
for_all_combinations<broadcast_non_uniform_group_test>(
TestType{}, unnamed_type_pack<double>{}, queue);
} else {
WARN("Device does not support double precision floating point operations.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,12 @@ namespace non_uniform_groups::tests {
// clang-format on
using ReduceTypes = Types;

TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint reduce functions",
"[oneapi_non_uniform_groups][group_func][type_list]") {
TEMPLATE_LIST_TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint reduce functions",
"[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes) {
auto queue = once_per_unit::get_queue();
const auto Operators = get_op_types<CTS_TYPE>();
const auto RetType = unnamed_type_pack<CTS_TYPE>();
const auto GroupTypes = unnamed_type_pack<
oneapi_ext::ballot_group<sycl::sub_group>,
oneapi_ext::fixed_size_group<1, sycl::sub_group>,
oneapi_ext::fixed_size_group<2, sycl::sub_group>,
oneapi_ext::fixed_size_group<4, sycl::sub_group>,
oneapi_ext::fixed_size_group<8, sycl::sub_group>,
oneapi_ext::tangle_group<sycl::sub_group>,
oneapi_ext::opportunistic_group>();
const auto GroupTypes = TestType{};

if constexpr (std::is_same_v<std::remove_cv_t<CTS_TYPE>, sycl::half>) {
if (!queue.get_device().has(sycl::aspect::fp16))
Expand All @@ -60,20 +53,13 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint reduce functions",

TEMPLATE_LIST_TEST_CASE(
CTS_TYPE_NAME + " non-uniform group joint reduce functions with init",
"[oneapi_non_uniform_groups][group_func][type_list]", ReduceTypes) {
"[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes) {
auto queue = once_per_unit::get_queue();

const auto Operators = get_op_types<CTS_TYPE>();
const auto RetType = unnamed_type_pack<CTS_TYPE>();
const auto ReducedType = unnamed_type_pack<TestType>();
const auto GroupTypes = unnamed_type_pack<
oneapi_ext::ballot_group<sycl::sub_group>,
oneapi_ext::fixed_size_group<1, sycl::sub_group>,
oneapi_ext::fixed_size_group<2, sycl::sub_group>,
oneapi_ext::fixed_size_group<4, sycl::sub_group>,
oneapi_ext::fixed_size_group<8, sycl::sub_group>,
oneapi_ext::tangle_group<sycl::sub_group>,
oneapi_ext::opportunistic_group>();
const auto ReducedType = Types{};
const auto GroupTypes = TestType{};

if constexpr (std::is_same_v<std::remove_cv_t<CTS_TYPE>, sycl::half>) {
if (!queue.get_device().has(sycl::aspect::fp16))
Expand Down
32 changes: 9 additions & 23 deletions tests/extension/oneapi_non_uniform_groups/group_joint_scan.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,14 @@

namespace non_uniform_groups::tests {

using TestType = unnamed_type_pack<CTS_TYPE>;
using CurrentType = unnamed_type_pack<CTS_TYPE>;
using ScanTypes = Types;
#endif // !SYCL_CTS_COMPILING_WITH_HIPSYCL

TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint scan functions",
"[oneapi_non_uniform_groups][group_func][type_list]"){
TEMPLATE_LIST_TEST_CASE(
CTS_TYPE_NAME + " non-uniform group joint scan functions",
"[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes){
auto queue = once_per_unit::get_queue();
const auto GroupTypes = unnamed_type_pack<
oneapi_ext::ballot_group<sycl::sub_group>,
oneapi_ext::fixed_size_group<1, sycl::sub_group>,
oneapi_ext::fixed_size_group<2, sycl::sub_group>,
oneapi_ext::fixed_size_group<4, sycl::sub_group>,
oneapi_ext::fixed_size_group<8, sycl::sub_group>,
oneapi_ext::tangle_group<sycl::sub_group>,
oneapi_ext::opportunistic_group>();

if constexpr (std::is_same_v<std::remove_cv_t<CTS_TYPE>, sycl::half>) {
if (!queue.get_device().has(sycl::aspect::fp16))
Expand All @@ -56,21 +49,14 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint scan functions",
"operations.");
}

for_all_combinations<invoke_joint_scan_group>(GroupTypes, TestType{},
for_all_combinations<invoke_joint_scan_group>(TestType{}, CurrentType{},
ScanTypes{}, queue);
};

TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint scan functions with init",
"[oneapi_non_uniform_groups][group_func][type_list]"){
TEMPLATE_LIST_TEST_CASE(
CTS_TYPE_NAME + " non-uniform group joint scan functions with init",
"[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes){
auto queue = once_per_unit::get_queue();
const auto GroupTypes = unnamed_type_pack<
oneapi_ext::ballot_group<sycl::sub_group>,
oneapi_ext::fixed_size_group<1, sycl::sub_group>,
oneapi_ext::fixed_size_group<2, sycl::sub_group>,
oneapi_ext::fixed_size_group<4, sycl::sub_group>,
oneapi_ext::fixed_size_group<8, sycl::sub_group>,
oneapi_ext::tangle_group<sycl::sub_group>,
oneapi_ext::opportunistic_group>();

if constexpr (std::is_same_v<std::remove_cv_t<CTS_TYPE>, sycl::half>) {
if (!queue.get_device().has(sycl::aspect::fp16))
Expand All @@ -85,7 +71,7 @@ TEST_CASE(CTS_TYPE_NAME + " non-uniform group joint scan functions with init",
}

for_all_combinations<invoke_init_joint_scan_group>(
GroupTypes, TestType{}, ScanTypes{}, ScanTypes{}, queue);
TestType{}, CurrentType{}, ScanTypes{}, ScanTypes{}, queue);
};

} // namespace non_uniform_groups::tests
58 changes: 12 additions & 46 deletions tests/extension/oneapi_non_uniform_groups/group_of.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,61 +23,27 @@
namespace non_uniform_groups::tests {

// use wide types to exclude truncation of init values
using WideTypes = std::tuple<int32_t, uint32_t, int64_t, uint64_t, float>;
static const auto wide_types =
named_type_pack<int32_t, uint32_t, int64_t, uint64_t, float>::generate(
"int32_t", "uint32_t", "int64_t", "uint64_t", "float");

TEMPLATE_LIST_TEST_CASE("Non-uniform group joint of bool functions",
"[oneapi_non_uniform_groups][group_func][type_list]",
WideTypes) {
auto queue = once_per_unit::get_queue();
joint_of_group<oneapi_ext::ballot_group<sycl::sub_group>, TestType>(queue);
joint_of_group<oneapi_ext::fixed_size_group<1, sycl::sub_group>, TestType>(
queue);
joint_of_group<oneapi_ext::fixed_size_group<2, sycl::sub_group>, TestType>(
queue);
joint_of_group<oneapi_ext::fixed_size_group<4, sycl::sub_group>, TestType>(
queue);
joint_of_group<oneapi_ext::fixed_size_group<8, sycl::sub_group>, TestType>(
queue);
joint_of_group<oneapi_ext::tangle_group<sycl::sub_group>, TestType>(queue);
joint_of_group<oneapi_ext::opportunistic_group, TestType>(queue);
GroupPackTypes) {
for_all_combinations<joint_of_group_test>(TestType{}, wide_types);
}

TEMPLATE_LIST_TEST_CASE(
"Non-uniform group of bool functions with predicate functions",
"[oneapi_non_uniform_groups][group_func][type_list]", WideTypes) {
auto queue = once_per_unit::get_queue();
predicate_function_of_non_uniform_group<
oneapi_ext::ballot_group<sycl::sub_group>, TestType>(queue);
predicate_function_of_non_uniform_group<
oneapi_ext::fixed_size_group<1, sycl::sub_group>, TestType>(queue);
predicate_function_of_non_uniform_group<
oneapi_ext::fixed_size_group<2, sycl::sub_group>, TestType>(queue);
predicate_function_of_non_uniform_group<
oneapi_ext::fixed_size_group<4, sycl::sub_group>, TestType>(queue);
predicate_function_of_non_uniform_group<
oneapi_ext::fixed_size_group<8, sycl::sub_group>, TestType>(queue);
predicate_function_of_non_uniform_group<
oneapi_ext::tangle_group<sycl::sub_group>, TestType>(queue);
predicate_function_of_non_uniform_group<oneapi_ext::opportunistic_group,
TestType>(queue);
"[oneapi_non_uniform_groups][group_func][type_list]", GroupPackTypes) {
for_all_combinations<predicate_function_of_non_uniform_group_test>(
TestType{}, wide_types);
}

TEST_CASE("Non-uniform group of bool functions",
"[oneapi_non_uniform_groups][group_func]") {
auto queue = once_per_unit::get_queue();
bool_function_of_non_uniform_group<oneapi_ext::ballot_group<sycl::sub_group>>(
queue);
bool_function_of_non_uniform_group<
oneapi_ext::fixed_size_group<1, sycl::sub_group>>(queue);
bool_function_of_non_uniform_group<
oneapi_ext::fixed_size_group<2, sycl::sub_group>>(queue);
bool_function_of_non_uniform_group<
oneapi_ext::fixed_size_group<4, sycl::sub_group>>(queue);
bool_function_of_non_uniform_group<
oneapi_ext::fixed_size_group<8, sycl::sub_group>>(queue);
bool_function_of_non_uniform_group<oneapi_ext::tangle_group<sycl::sub_group>>(
queue);
bool_function_of_non_uniform_group<oneapi_ext::opportunistic_group>(queue);
TEMPLATE_LIST_TEST_CASE("Non-uniform group of bool functions",
"[oneapi_non_uniform_groups][group_func]",
GroupPackTypes) {
for_all_combinations<bool_function_of_non_uniform_group_test>(TestType{});
}

} // namespace non_uniform_groups::tests
Loading

0 comments on commit b4e3600

Please sign in to comment.