Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DFT] Correct overload resolution for OOP COMPLEX vs IP REAL_REAL #503

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/oneapi/mkl/dft/backward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout)
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
sycl::buffer<data_type, 1> &inout_im) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand Down Expand Up @@ -114,7 +115,8 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout,
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand Down
22 changes: 20 additions & 2 deletions include/oneapi/mkl/dft/detail/types_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,34 @@ struct descriptor_info<descriptor<precision::DOUBLE, domain::COMPLEX>> {
using backward_type = std::complex<double>;
};

// Get the scalar type associated with a descriptor.
template <class descriptor_t>
using descriptor_scalar_t = typename descriptor_info<descriptor_t>::scalar_type;

template <typename T>
constexpr bool is_complex_dft = false;
template <precision Prec>
constexpr bool is_complex_dft<descriptor<Prec, domain::COMPLEX>> = true;

template <typename T>
constexpr bool is_complex = false;
template <typename T>
constexpr bool is_complex<std::complex<T>> = true;

template <typename T, typename... Ts>
using is_one_of = typename std::bool_constant<(std::is_same_v<T, Ts> || ...)>;

template <typename descriptor_type, typename T>
using valid_compute_arg = typename std::bool_constant<
(std::is_same_v<typename detail::descriptor_info<descriptor_type>::scalar_type, float> &&
(std::is_same_v<descriptor_scalar_t<descriptor_type>, float> &&
is_one_of<T, float, sycl::float2, sycl::float4, std::complex<float>>::value) ||
(std::is_same_v<typename detail::descriptor_info<descriptor_type>::scalar_type, double> &&
(std::is_same_v<descriptor_scalar_t<descriptor_type>, double> &&
is_one_of<T, double, sycl::double2, sycl::double4, std::complex<double>>::value)>;

template <class descriptor_t, typename data_t>
constexpr bool valid_ip_realreal_impl =
is_complex_dft<descriptor_t>&& std::is_same_v<descriptor_scalar_t<descriptor_t>, data_t>;
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
hjabird marked this conversation as resolved.
Show resolved Hide resolved

// compute the range of a reinterpreted buffer
template <typename In, typename Out>
std::size_t reinterpret_range(std::size_t size) {
Expand Down
8 changes: 4 additions & 4 deletions include/oneapi/mkl/dft/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout) {
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
sycl::buffer<data_type, 1> &inout_im) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
Expand Down Expand Up @@ -114,12 +115,12 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout,
}

//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template <typename descriptor_type, typename data_type>
template <typename descriptor_type, typename data_type,
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
const std::vector<sycl::event> &dependencies = {}) {
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
"unexpected type for data_type");

using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast<scalar_type *>(inout_re),
reinterpret_cast<scalar_type *>(inout_im), dependencies);
Expand All @@ -133,7 +134,6 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *
"unexpected type for input_type");
static_assert(detail::valid_compute_arg<descriptor_type, output_type>::value,
"unexpected type for output_type");

using fwd_type = typename detail::descriptor_info<descriptor_type>::forward_type;
using bwd_type = typename detail::descriptor_info<descriptor_type>::backward_type;
return get_commit(desc)->forward_op_cc(desc, reinterpret_cast<fwd_type *>(in),
Expand Down
Loading