Skip to content

Commit

Permalink
[SYCL][COMPAT] match_*_over_sub_group tests. Fixed match_all to match…
Browse files Browse the repository at this point in the history
… the documented behavior.
  • Loading branch information
Alcpz committed Feb 29, 2024
1 parent 32bc5bb commit 7293dca
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sycl/include/syclcompat/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ unsigned int match_all_over_sub_group(sycl::sub_group g, unsigned member_mask,
sycl::plus<>());
bool all_equal = (reduce_result == member_mask);
*pred = is_participate & all_equal;
return all_equal * member_mask;
return (is_participate & all_equal) * member_mask;
}

namespace experimental {
Expand Down
105 changes: 105 additions & 0 deletions sycl/test-e2e/syclcompat/util/util_match_all_over_group.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/***************************************************************************
*
* Copyright (C) Codeplay Software Ltd.
*
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SYCLcompat API
*
* util_match_all_over_group.cpp
*
* Description:
* util_match_all_over_group tests
**************************************************************************/

// The original source was under the license below:
// ====------ UtilSelectFromSubGroup.cpp---------- -*- C++ -* ----===////
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//
// ===----------------------------------------------------------------------===//

// RUN: %clangxx -fsycl -fsycl-targets=%{sycl_triple} %s -o %t.out
// RUN: %{run} %t.out

#include <sycl/sycl.hpp>
#include <syclcompat.hpp>

constexpr unsigned int NUM_TESTS = 3;
constexpr unsigned int SUBGROUP_SIZE = 16;
constexpr unsigned int DATA_SIZE = NUM_TESTS * SUBGROUP_SIZE;

void test_select_from_sub_group() {
std::cout << __PRETTY_FUNCTION__ << std::endl;

constexpr syclcompat::dim3 grid{1};
constexpr syclcompat::dim3 threads{SUBGROUP_SIZE};

unsigned int input[DATA_SIZE] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, // #1
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, // #2
0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1}; // #3
unsigned int output[DATA_SIZE];
int pred[DATA_SIZE];
unsigned int *d_input = syclcompat::malloc<unsigned int>(DATA_SIZE);
unsigned int *d_output = syclcompat::malloc<unsigned int>(DATA_SIZE);
int *d_pred = syclcompat::malloc<int>(DATA_SIZE);

unsigned int member_mask = 0x00FF;
unsigned int expected[DATA_SIZE] = {
0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF,
0, 0, 0, 0, 0, 0, 0, 0, // #1
0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF, 0x00FF,
0, 0, 0, 0, 0, 0, 0, 0, // #2
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, // #3
};
unsigned int expected_pred[DATA_SIZE] = {
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, // #1
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, // #2
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // #3
};

syclcompat::memcpy<unsigned int>(d_input, input, DATA_SIZE);
syclcompat::memset(d_output, 0, DATA_SIZE * sizeof(unsigned int));
syclcompat::memset(d_pred, 1, DATA_SIZE * sizeof(int));

sycl::queue q = syclcompat::get_default_queue();
q.parallel_for(
sycl::nd_range<1>(threads.size(), threads.size()),
[=](sycl::nd_item<1> item) [[intel::reqd_sub_group_size(SUBGROUP_SIZE)]] {
for (auto id = item.get_global_linear_id(); id < DATA_SIZE;
id += SUBGROUP_SIZE)
d_output[id] = syclcompat::match_all_over_sub_group(
item.get_sub_group(), member_mask, d_input[id], &d_pred[id]);
});
q.wait_and_throw();
syclcompat::memcpy<unsigned int>(output, d_output, DATA_SIZE);
syclcompat::memcpy<int>(pred, d_pred, DATA_SIZE);

for (int i = 0; i < DATA_SIZE; ++i) {
assert(output[i] == expected[i]);
assert(pred[i] == expected_pred[i]);
}

syclcompat::free(d_input);
syclcompat::free(d_output);
syclcompat::free(d_pred);
}

int main() {
test_select_from_sub_group();

return 0;
}
92 changes: 92 additions & 0 deletions sycl/test-e2e/syclcompat/util/util_match_any_over_group.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/***************************************************************************
*
* Copyright (C) Codeplay Software Ltd.
*
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SYCLcompat API
*
* util_match_any_over_group.cpp
*
* Description:
* util_match_any_over_group tests
**************************************************************************/

// The original source was under the license below:
// ====------ UtilSelectFromSubGroup.cpp---------- -*- C++ -* ----===////
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//
// ===----------------------------------------------------------------------===//

// RUN: %clangxx -fsycl -fsycl-targets=%{sycl_triple} %s -o %t.out
// RUN: %{run} %t.out

#include <sycl/sycl.hpp>
#include <syclcompat.hpp>

#define DATA_SIZE 64
#define SUBGROUP_SIZE 16

void test_select_from_sub_group() {
std::cout << __PRETTY_FUNCTION__ << std::endl;

constexpr syclcompat::dim3 grid{1};
constexpr syclcompat::dim3 threads{DATA_SIZE};

unsigned int input[DATA_SIZE] = {
0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 1, 1,
1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
3, 3, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
unsigned int output[DATA_SIZE];
unsigned int *d_input = syclcompat::malloc<unsigned int>(DATA_SIZE);
unsigned int *d_output = syclcompat::malloc<unsigned int>(DATA_SIZE);

unsigned int member_mask = 0x0FFF;
unsigned int expected[DATA_SIZE] = {
0x000F, 0x000F, 0x000F, 0x000F, 0x00F0, 0x00F0, 0x00F0, 0x00F0,
0x0F00, 0x0F00, 0x0F00, 0x0F00, 0, 0, 0, 0,
0x000F, 0x000F, 0x000F, 0x000F, 0x00F0, 0x00F0, 0x00F0, 0x00F0,
0x0F00, 0x0F00, 0x0F00, 0x0F00, 0, 0, 0, 0,
0x000F, 0x000F, 0x000F, 0x000F, 0x00F0, 0x00F0, 0x00F0, 0x00F0,
0x0F00, 0x0F00, 0x0F00, 0x0F00, 0, 0, 0, 0,
0x000F, 0x000F, 0x000F, 0x000F, 0x00F0, 0x00F0, 0x00F0, 0x00F0,
0x0F00, 0x0F00, 0x0F00, 0x0F00, 0, 0, 0, 0,
};

syclcompat::memcpy<unsigned int>(d_input, input, DATA_SIZE);
sycl::queue q = syclcompat::get_default_queue();
q.parallel_for(
sycl::nd_range<1>(grid.size() * threads.size(), threads.size()),
[=](sycl::nd_item<1> item) [[intel::reqd_sub_group_size(SUBGROUP_SIZE)]] {
auto id = item.get_global_linear_id();
d_output[id] = syclcompat::match_any_over_sub_group(
item.get_sub_group(), member_mask, d_input[id]);
});
q.wait_and_throw();
syclcompat::memcpy<unsigned int>(output, d_output, DATA_SIZE);

for (int i = 0; i < DATA_SIZE; ++i) {
assert(output[i] == expected[i]);
}

syclcompat::free(d_input);
syclcompat::free(d_output);
}

int main() {
test_select_from_sub_group();

return 0;
}

0 comments on commit 7293dca

Please sign in to comment.