Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DFT] Add static_assert to check types passed into compute_foward and compute_backward #502

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion include/oneapi/mkl/dft/backward.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
* Copyright 2023-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -49,6 +49,7 @@ void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_r
sycl::buffer<data_type, 1> &inout_im) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
"unexpected type for data_type");
static_assert(!detail::is_complex_arg<data_type>::value, "expected real type for data_type");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
auto type_corrected_inout_re = inout_re.template reinterpret<scalar_type, 1>(
Expand All @@ -69,6 +70,13 @@ void compute_backward(descriptor_type &desc, sycl::buffer<input_type, 1> &in,

using fwd_type = typename detail::descriptor_info<descriptor_type>::forward_type;
using bwd_type = typename detail::descriptor_info<descriptor_type>::backward_type;

// If the DFT is COMPLEX, the input and output types are expected to be complex.
static_assert(
!std::is_same_v<fwd_type, bwd_type> || (detail::is_complex_arg<input_type>::value &&
detail::is_complex_arg<output_type>::value),
"expected std::complex input_type and output_type");

auto type_corrected_in = in.template reinterpret<bwd_type, 1>(
detail::reinterpret_range<input_type, bwd_type>(in.size()));
auto type_corrected_out = out.template reinterpret<fwd_type, 1>(
Expand All @@ -85,6 +93,9 @@ void compute_backward(descriptor_type &desc, sycl::buffer<input_type, 1> &in_re,
"unexpected type for input_type");
static_assert(detail::valid_compute_arg<descriptor_type, output_type>::value,
"unexpected type for output_type");
static_assert(
!detail::is_complex_arg<input_type>::value && !detail::is_complex_arg<output_type>::value,
"expected input_type and output_type to be real");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
auto type_corrected_in_re = in_re.template reinterpret<scalar_type, 1>(
Expand Down Expand Up @@ -119,6 +130,7 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_ty
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
"unexpected type for data_type");
static_assert(!detail::is_complex_arg<data_type>::value, "data_type is expected to be real");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
return get_commit(desc)->backward_ip_rr(desc, reinterpret_cast<scalar_type *>(inout_re),
Expand All @@ -137,6 +149,13 @@ sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type

using fwd_type = typename detail::descriptor_info<descriptor_type>::forward_type;
using bwd_type = typename detail::descriptor_info<descriptor_type>::backward_type;

// If the DFT is COMPLEX, the input and output types are expected to be complex.
static_assert(
!std::is_same_v<fwd_type, bwd_type> || (detail::is_complex_arg<input_type>::value &&
detail::is_complex_arg<output_type>::value),
"expected std::complex input_type and output_type");

return get_commit(desc)->backward_op_cc(desc, reinterpret_cast<bwd_type *>(in),
reinterpret_cast<fwd_type *>(out), dependencies);
}
Expand All @@ -150,6 +169,9 @@ sycl::event compute_backward(descriptor_type &desc, input_type *in_re, input_typ
"unexpected type for input_type");
static_assert(detail::valid_compute_arg<descriptor_type, output_type>::value,
"unexpected type for output_type");
static_assert(
!detail::is_complex_arg<input_type>::value && !detail::is_complex_arg<output_type>::value,
"expected input_type and output_type to be real");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
return get_commit(desc)->backward_op_rr(desc, reinterpret_cast<scalar_type *>(in_re),
Expand Down
8 changes: 7 additions & 1 deletion include/oneapi/mkl/dft/detail/types_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2020-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -97,6 +97,12 @@ using valid_compute_arg = typename std::bool_constant<
(std::is_same_v<typename detail::descriptor_info<descriptor_type>::scalar_type, double> &&
is_one_of<T, double, sycl::double2, sycl::double4, std::complex<double>>::value)>;

template <typename T>
class is_complex_arg : public std::false_type {};

template <typename T>
class is_complex_arg<std::complex<T>> : public std::true_type {};

// compute the range of a reinterpreted buffer
template <typename In, typename Out>
std::size_t reinterpret_range(std::size_t size) {
Expand Down
24 changes: 23 additions & 1 deletion include/oneapi/mkl/dft/forward.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
* Copyright 2023-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,6 +50,7 @@ void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re
sycl::buffer<data_type, 1> &inout_im) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
"unexpected type for data_type");
static_assert(!detail::is_complex_arg<data_type>::value, "expected real type for data_type");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
auto type_corrected_inout_re = inout_re.template reinterpret<scalar_type, 1>(
Expand All @@ -70,6 +71,13 @@ void compute_forward(descriptor_type &desc, sycl::buffer<input_type, 1> &in,

using fwd_type = typename detail::descriptor_info<descriptor_type>::forward_type;
using bwd_type = typename detail::descriptor_info<descriptor_type>::backward_type;

// If the DFT is COMPLEX, the input and output types are expected to be complex.
static_assert(
!std::is_same_v<fwd_type, bwd_type> || (detail::is_complex_arg<input_type>::value &&
detail::is_complex_arg<output_type>::value),
"expected std::complex input_type and output_type");

auto type_corrected_in = in.template reinterpret<fwd_type, 1>(
detail::reinterpret_range<input_type, fwd_type>(in.size()));
auto type_corrected_out = out.template reinterpret<bwd_type, 1>(
Expand All @@ -86,6 +94,9 @@ void compute_forward(descriptor_type &desc, sycl::buffer<input_type, 1> &in_re,
"unexpected type for input_type");
static_assert(detail::valid_compute_arg<descriptor_type, output_type>::value,
"unexpected type for output_type");
static_assert(
!detail::is_complex_arg<input_type>::value && !detail::is_complex_arg<output_type>::value,
"expected input_type and output_type to be real");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
auto type_corrected_in_re = in_re.template reinterpret<scalar_type, 1>(
Expand Down Expand Up @@ -119,6 +130,7 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_typ
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
"unexpected type for data_type");
static_assert(!detail::is_complex_arg<data_type>::value, "data_type is expected to be real");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast<scalar_type *>(inout_re),
Expand All @@ -136,6 +148,13 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *

using fwd_type = typename detail::descriptor_info<descriptor_type>::forward_type;
using bwd_type = typename detail::descriptor_info<descriptor_type>::backward_type;

// If the DFT is COMPLEX, the input and output types are expected to be complex.
static_assert(
!std::is_same_v<fwd_type, bwd_type> || (detail::is_complex_arg<input_type>::value &&
detail::is_complex_arg<output_type>::value),
"expected std::complex input_type and output_type");

return get_commit(desc)->forward_op_cc(desc, reinterpret_cast<fwd_type *>(in),
reinterpret_cast<bwd_type *>(out), dependencies);
}
Expand All @@ -149,6 +168,9 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in_re, input_type
"unexpected type for input_type");
static_assert(detail::valid_compute_arg<descriptor_type, output_type>::value,
"unexpected type for output_type");
static_assert(
!detail::is_complex_arg<input_type>::value && !detail::is_complex_arg<output_type>::value,
"expected input_type and output_type to be real");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
return get_commit(desc)->forward_op_rr(desc, reinterpret_cast<scalar_type *>(in_re),
Expand Down