From 0d301e65a0ee88277318092e62d867bcccc09573 Mon Sep 17 00:00:00 2001 From: HJA Bird Date: Tue, 2 Jul 2024 15:05:00 +0100 Subject: [PATCH] [DFT] Correct overload resolution for OOP COMPLEX vs IP REAL_REAL (#503) * OOP COMPLEX and IP REAL_REAL overload resolution is problematic * Inplace real-real overload would be selected when out-of-place complex-complex DFT was intended. * With spec update, this PR uses SFINAE to give the expected behaviour for the user. --- include/oneapi/mkl/dft/backward.hpp | 6 ++++-- include/oneapi/mkl/dft/detail/types_impl.hpp | 22 ++++++++++++++++++-- include/oneapi/mkl/dft/forward.hpp | 8 +++---- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/include/oneapi/mkl/dft/backward.hpp b/include/oneapi/mkl/dft/backward.hpp index b63dabb28..3cd03e13b 100644 --- a/include/oneapi/mkl/dft/backward.hpp +++ b/include/oneapi/mkl/dft/backward.hpp @@ -44,7 +44,8 @@ void compute_backward(descriptor_type &desc, sycl::buffer &inout) } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, sycl::buffer &inout_im) { static_assert(detail::valid_compute_arg::value, @@ -114,7 +115,8 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout, } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp index e41a71469..6fd07694b 100644 --- a/include/oneapi/mkl/dft/detail/types_impl.hpp +++ b/include/oneapi/mkl/dft/detail/types_impl.hpp @@ -87,16 +87,34 @@ struct descriptor_info> { using backward_type = std::complex; }; +// Get the scalar type associated with a descriptor. +template +using descriptor_scalar_t = typename descriptor_info::scalar_type; + +template +constexpr bool is_complex_dft = false; +template +constexpr bool is_complex_dft> = true; + +template +constexpr bool is_complex = false; +template +constexpr bool is_complex> = true; + template using is_one_of = typename std::bool_constant<(std::is_same_v || ...)>; template using valid_compute_arg = typename std::bool_constant< - (std::is_same_v::scalar_type, float> && + (std::is_same_v, float> && is_one_of>::value) || - (std::is_same_v::scalar_type, double> && + (std::is_same_v, double> && is_one_of>::value)>; +template +constexpr bool valid_ip_realreal_impl = + is_complex_dft&& std::is_same_v, data_t>; + // compute the range of a reinterpreted buffer template std::size_t reinterpret_range(std::size_t size) { diff --git a/include/oneapi/mkl/dft/forward.hpp b/include/oneapi/mkl/dft/forward.hpp index 3fcd60b1c..e43c39ce0 100644 --- a/include/oneapi/mkl/dft/forward.hpp +++ b/include/oneapi/mkl/dft/forward.hpp @@ -45,7 +45,8 @@ void compute_forward(descriptor_type &desc, sycl::buffer &inout) { } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, sycl::buffer &inout_im) { static_assert(detail::valid_compute_arg::value, @@ -114,12 +115,12 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout, } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, "unexpected type for data_type"); - using scalar_type = typename detail::descriptor_info::scalar_type; return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast(inout_re), reinterpret_cast(inout_im), dependencies); @@ -133,7 +134,6 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type * "unexpected type for input_type"); static_assert(detail::valid_compute_arg::value, "unexpected type for output_type"); - using fwd_type = typename detail::descriptor_info::forward_type; using bwd_type = typename detail::descriptor_info::backward_type; return get_commit(desc)->forward_op_cc(desc, reinterpret_cast(in),