From 8184818a61e12c0f656a12ea45f55819b2a72a31 Mon Sep 17 00:00:00 2001 From: Hugh Bird Date: Thu, 30 May 2024 16:32:07 +0100 Subject: [PATCH 1/3] [DFT] Correct overload resolution for OOP COMPLEX vs IP REAL_REAL * OOP COMPLEX and IP REAL_REAL overload resolution is problematic * Correct with SFINAE --- include/oneapi/mkl/dft/backward.hpp | 14 ++++++--- include/oneapi/mkl/dft/detail/types_impl.hpp | 32 ++++++++++++++++++++ include/oneapi/mkl/dft/forward.hpp | 16 ++++++---- 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/include/oneapi/mkl/dft/backward.hpp b/include/oneapi/mkl/dft/backward.hpp index b63dabb28..5eafe48c5 100644 --- a/include/oneapi/mkl/dft/backward.hpp +++ b/include/oneapi/mkl/dft/backward.hpp @@ -44,7 +44,8 @@ void compute_backward(descriptor_type &desc, sycl::buffer &inout) } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, sycl::buffer &inout_im) { static_assert(detail::valid_compute_arg::value, @@ -59,7 +60,9 @@ void compute_backward(descriptor_type &desc, sycl::buffer &inout_r } //Out-of-place transform -template +template , + bool> = true> void compute_backward(descriptor_type &desc, sycl::buffer &in, sycl::buffer &out) { static_assert(detail::valid_compute_arg::value, @@ -114,7 +117,8 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout, } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, @@ -127,7 +131,9 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_ty } //Out-of-place transform -template +template , + bool> = true> sycl::event compute_backward(descriptor_type &desc, 
input_type *in, output_type *out, const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp index e41a71469..75486530a 100644 --- a/include/oneapi/mkl/dft/detail/types_impl.hpp +++ b/include/oneapi/mkl/dft/detail/types_impl.hpp @@ -87,6 +87,20 @@ struct descriptor_info> { using backward_type = std::complex; }; +// Get the scalar type associated with a descriptor. +template +using descriptor_scalar_t = typename descriptor_info::scalar_type; + +template +constexpr bool is_complex_dft = false; +template +constexpr bool is_complex_dft> = true; + +template +constexpr bool is_complex = false; +template +constexpr bool is_complex> = true; + template using is_one_of = typename std::bool_constant<(std::is_same_v || ...)>; @@ -97,6 +111,24 @@ using valid_compute_arg = typename std::bool_constant< (std::is_same_v::scalar_type, double> && is_one_of>::value)>; +// For out-of-place complex-complex DFTs, are the input and output types correct? For SFINAE. +template +constexpr bool valid_oop_iotypes = []() { + if constexpr (is_complex_dft) { + // Both input and output types must be complex, otherwise select real-real inplace overload. + return is_complex && is_complex; + } + else { + // I/O can be real or complex - no issues resolving overload with real-real inplace. 
+ return valid_compute_arg::value && + valid_compute_arg::value; + } +}(); + +template +constexpr bool valid_ip_realreal_impl = + is_complex_dft&& std::is_same_v, data_t>; + // compute the range of a reinterpreted buffer template std::size_t reinterpret_range(std::size_t size) { diff --git a/include/oneapi/mkl/dft/forward.hpp b/include/oneapi/mkl/dft/forward.hpp index 3fcd60b1c..353a5a89f 100644 --- a/include/oneapi/mkl/dft/forward.hpp +++ b/include/oneapi/mkl/dft/forward.hpp @@ -45,7 +45,8 @@ void compute_forward(descriptor_type &desc, sycl::buffer &inout) { } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, sycl::buffer &inout_im) { static_assert(detail::valid_compute_arg::value, @@ -60,7 +61,9 @@ void compute_forward(descriptor_type &desc, sycl::buffer &inout_re } //Out-of-place transform -template +template , + bool> = true> void compute_forward(descriptor_type &desc, sycl::buffer &in, sycl::buffer &out) { static_assert(detail::valid_compute_arg::value, @@ -114,26 +117,27 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout, } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, "unexpected type for data_type"); - using scalar_type = typename detail::descriptor_info::scalar_type; return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast(inout_re), reinterpret_cast(inout_im), dependencies); } //Out-of-place transform -template +template , + bool> = true> sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, "unexpected type 
for input_type"); static_assert(detail::valid_compute_arg::value, "unexpected type for output_type"); - using fwd_type = typename detail::descriptor_info::forward_type; using bwd_type = typename detail::descriptor_info::backward_type; return get_commit(desc)->forward_op_cc(desc, reinterpret_cast(in), From 5ad57d36a85d7b2ebbcf8e0a136127c404095e27 Mon Sep 17 00:00:00 2001 From: Hugh Bird Date: Fri, 31 May 2024 14:01:13 +0100 Subject: [PATCH 2/3] Use descriptor_scalar_t elsewhere --- include/oneapi/mkl/dft/detail/types_impl.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp index 75486530a..131e27c4d 100644 --- a/include/oneapi/mkl/dft/detail/types_impl.hpp +++ b/include/oneapi/mkl/dft/detail/types_impl.hpp @@ -106,9 +106,9 @@ using is_one_of = typename std::bool_constant<(std::is_same_v || ...)>; template using valid_compute_arg = typename std::bool_constant< - (std::is_same_v::scalar_type, float> && + (std::is_same_v, float> && is_one_of>::value) || - (std::is_same_v::scalar_type, double> && + (std::is_same_v, double> && is_one_of>::value)>; // For out-of-place complex-complex DFTs, are the input and output types correct? For SFINAE. 
From 516c7fa16a444a9ed400cfbb2b118cad970831f3 Mon Sep 17 00:00:00 2001 From: Hugh Bird Date: Thu, 6 Jun 2024 11:43:41 +0100 Subject: [PATCH 3/3] Reenable OOP complex-complex version for float input --- include/oneapi/mkl/dft/backward.hpp | 8 ++------ include/oneapi/mkl/dft/detail/types_impl.hpp | 14 -------------- include/oneapi/mkl/dft/forward.hpp | 8 ++------ 3 files changed, 4 insertions(+), 26 deletions(-) diff --git a/include/oneapi/mkl/dft/backward.hpp b/include/oneapi/mkl/dft/backward.hpp index 5eafe48c5..3cd03e13b 100644 --- a/include/oneapi/mkl/dft/backward.hpp +++ b/include/oneapi/mkl/dft/backward.hpp @@ -60,9 +60,7 @@ void compute_backward(descriptor_type &desc, sycl::buffer &inout_r } //Out-of-place transform -template , - bool> = true> +template void compute_backward(descriptor_type &desc, sycl::buffer &in, sycl::buffer &out) { static_assert(detail::valid_compute_arg::value, @@ -131,9 +129,7 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_ty } //Out-of-place transform -template , - bool> = true> +template sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp index 131e27c4d..6fd07694b 100644 --- a/include/oneapi/mkl/dft/detail/types_impl.hpp +++ b/include/oneapi/mkl/dft/detail/types_impl.hpp @@ -111,20 +111,6 @@ using valid_compute_arg = typename std::bool_constant< (std::is_same_v, double> && is_one_of>::value)>; -// For out-of-place complex-complex DFTs, are the input and output types correct? For SFINAE. -template -constexpr bool valid_oop_iotypes = []() { - if constexpr (is_complex_dft) { - // Both input and output types must be complex, otherwise select real-real inplace overload. 
- return is_complex && is_complex; - } - else { - // I/O can be real or complex - no issues resolving overload with real-real inplace. - return valid_compute_arg::value && - valid_compute_arg::value; - } -}(); - template constexpr bool valid_ip_realreal_impl = is_complex_dft&& std::is_same_v, data_t>; diff --git a/include/oneapi/mkl/dft/forward.hpp b/include/oneapi/mkl/dft/forward.hpp index 353a5a89f..e43c39ce0 100644 --- a/include/oneapi/mkl/dft/forward.hpp +++ b/include/oneapi/mkl/dft/forward.hpp @@ -61,9 +61,7 @@ void compute_forward(descriptor_type &desc, sycl::buffer &inout_re } //Out-of-place transform -template , - bool> = true> +template void compute_forward(descriptor_type &desc, sycl::buffer &in, sycl::buffer &out) { static_assert(detail::valid_compute_arg::value, @@ -129,9 +127,7 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_typ } //Out-of-place transform -template , - bool> = true> +template sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value,