Skip to content

Commit

Permalink
[DFT] Correct overload resolution for OOP COMPLEX vs IP REAL_REAL
Browse files Browse the repository at this point in the history
* OOP COMPLEX and IP REAL_REAL overload resolution is problematic
* Correct with SFINAE
  • Loading branch information
hjabird committed May 30, 2024
1 parent f9983ee commit 6026d99
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 10 deletions.
14 changes: 10 additions & 4 deletions include/oneapi/mkl/dft/backward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout)
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
sycl::buffer<data_type, 1> &inout_im) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand All @@ -59,7 +60,9 @@ void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_r
}

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type>
template <typename descriptor_type, typename input_type, typename output_type,
std::enable_if_t<detail::valid_oop_iotypes<descriptor_type, input_type, output_type>,
bool> = true>
void compute_backward(descriptor_type &desc, sycl::buffer<input_type, 1> &in,
sycl::buffer<output_type, 1> &out) {
static_assert(detail::valid_compute_arg<descriptor_type, input_type>::value,
Expand Down Expand Up @@ -114,7 +117,8 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout,
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand All @@ -127,7 +131,9 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_ty
}

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type>
template <typename descriptor_type, typename input_type, typename output_type,
std::enable_if_t<detail::valid_oop_iotypes<descriptor_type, input_type, output_type>,
bool> = true>
sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, input_type>::value,
Expand Down
32 changes: 32 additions & 0 deletions include/oneapi/mkl/dft/detail/types_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,20 @@ struct descriptor_info<descriptor<precision::DOUBLE, domain::COMPLEX>> {
using backward_type = std::complex<double>;
};

// Shorthand for the scalar_type member that descriptor_info exposes for a given descriptor type.
template <typename descriptor_t>
using descriptor_scalar_t = typename descriptor_info<descriptor_t>::scalar_type;

// Trait: true only when T is descriptor<Prec, domain::COMPLEX> for some precision Prec;
// false for every other type.
// Declared `inline` so that every translation unit including this header refers to the
// same variable-template definition.
template <typename T>
inline constexpr bool is_complex_dft = false;
template <precision Prec>
inline constexpr bool is_complex_dft<descriptor<Prec, domain::COMPLEX>> = true;

// Trait: true only when T is exactly std::complex<U> for some U. cv-qualified or
// reference-to-complex types do not match the specialization and therefore yield false.
// Declared `inline` so that every translation unit including this header refers to the
// same variable-template definition.
template <typename T>
inline constexpr bool is_complex = false;
template <typename T>
inline constexpr bool is_complex<std::complex<T>> = true;

// Type trait whose ::value is true exactly when T is the same type as at least one of Ts.
// With an empty Ts pack the fold over || is empty, so ::value is false.
template <typename T, typename... Ts>
using is_one_of = std::bool_constant<(... || std::is_same_v<T, Ts>)>;

Expand All @@ -97,6 +111,24 @@ using valid_compute_arg = typename std::bool_constant<
(std::is_same_v<typename detail::descriptor_info<descriptor_type>::scalar_type, double> &&
is_one_of<T, double, sycl::double2, sycl::double4, std::complex<double>>::value)>;

// SFINAE predicate for the out-of-place compute overloads: are input_t and output_t acceptable
// for descriptor_t?
//  - Complex-domain descriptor: both input_t and output_t must be std::complex types. Anything
//    else is rejected so that a call with two identical real-typed arguments resolves to the
//    in-place REAL_REAL overload instead of this one.
//  - Any other descriptor: both types only need to satisfy valid_compute_arg; there is no
//    clash with the in-place REAL_REAL overload in that case.
// NOTE(review): for complex descriptors this also rejects the vector types (e.g. sycl::double2)
// that valid_compute_arg itself accepts - confirm that this is intended.
// Declared `inline` so that every translation unit including this header refers to the
// same variable-template definition.
template <class descriptor_t, typename input_t, typename output_t>
inline constexpr bool valid_oop_iotypes = []() {
    if constexpr (is_complex_dft<descriptor_t>) {
        // Both input and output types must be complex, otherwise select real-real inplace overload.
        return is_complex<input_t> && is_complex<output_t>;
    }
    else {
        // I/O can be real or complex - no issues resolving overload with real-real inplace.
        return valid_compute_arg<descriptor_t, input_t>::value &&
               valid_compute_arg<descriptor_t, output_t>::value;
    }
}();

// SFINAE predicate for the in-place REAL_REAL compute overloads: true only when descriptor_t
// is a complex-domain descriptor and data_t is exactly that descriptor's scalar type.
// Declared `inline` so that every translation unit including this header refers to the
// same variable-template definition.
template <class descriptor_t, typename data_t>
inline constexpr bool valid_ip_realreal_impl =
    is_complex_dft<descriptor_t> && std::is_same_v<descriptor_scalar_t<descriptor_t>, data_t>;

// compute the range of a reinterpreted buffer
template <typename In, typename Out>
std::size_t reinterpret_range(std::size_t size) {
Expand Down
16 changes: 10 additions & 6 deletions include/oneapi/mkl/dft/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout) {
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
sycl::buffer<data_type, 1> &inout_im) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand All @@ -60,7 +61,9 @@ void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re
}

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type>
template <typename descriptor_type, typename input_type, typename output_type,
std::enable_if_t<detail::valid_oop_iotypes<descriptor_type, input_type, output_type>,
bool> = true>
void compute_forward(descriptor_type &desc, sycl::buffer<input_type, 1> &in,
sycl::buffer<output_type, 1> &out) {
static_assert(detail::valid_compute_arg<descriptor_type, input_type>::value,
Expand Down Expand Up @@ -114,26 +117,27 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout,
}

// In-place forward transform, using the config_param::COMPLEX_STORAGE=config_value::REAL_REAL
// data format.
// Parameters:
//   desc         - descriptor; its commit object (via get_commit) performs the transform.
//   inout_re     - first data pointer (presumably the real parts - confirm against the backend).
//   inout_im     - second data pointer (presumably the imaginary parts - confirm).
//   dependencies - events forwarded unchanged to the backend call; defaults to none.
// Returns the sycl::event produced by the commit object's forward_ip_rr call.
// In the enable_if form of the template header, the overload only participates in overload
// resolution when detail::valid_ip_realreal_impl<descriptor_type, data_type> is true.
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
const std::vector<sycl::event> &dependencies = {}) {
// Compile-time guard: data_type must be one of the types valid_compute_arg allows for this descriptor.
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
"unexpected type for data_type");

// The backend entry point is typed on the descriptor's scalar type, so both pointers are
// reinterpreted to scalar_type* before being forwarded.
using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast<scalar_type *>(inout_re),
reinterpret_cast<scalar_type *>(inout_im), dependencies);
}

//Out-of-place transform
template <typename descriptor_type, typename input_type, typename output_type>
template <typename descriptor_type, typename input_type, typename output_type,
std::enable_if_t<detail::valid_oop_iotypes<descriptor_type, input_type, output_type>,
bool> = true>
sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, input_type>::value,
"unexpected type for input_type");
static_assert(detail::valid_compute_arg<descriptor_type, output_type>::value,
"unexpected type for output_type");

using fwd_type = typename detail::descriptor_info<descriptor_type>::forward_type;
using bwd_type = typename detail::descriptor_info<descriptor_type>::backward_type;
return get_commit(desc)->forward_op_cc(desc, reinterpret_cast<fwd_type *>(in),
Expand Down

0 comments on commit 6026d99

Please sign in to comment.