From 62184efbf21800f732825eb04caeff8598d3c047 Mon Sep 17 00:00:00 2001 From: Hugh Bird Date: Tue, 28 May 2024 16:18:55 +0100 Subject: [PATCH] [DFT] Add static_assert to check types passed into compute_???ward * static assert that the types are real or complex as required --- include/oneapi/mkl/dft/backward.hpp | 24 +++++++++++++++++++- include/oneapi/mkl/dft/detail/types_impl.hpp | 8 ++++++- include/oneapi/mkl/dft/forward.hpp | 24 +++++++++++++++++++- 3 files changed, 53 insertions(+), 3 deletions(-) diff --git a/include/oneapi/mkl/dft/backward.hpp b/include/oneapi/mkl/dft/backward.hpp index b63dabb28..094ce0eef 100644 --- a/include/oneapi/mkl/dft/backward.hpp +++ b/include/oneapi/mkl/dft/backward.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023 Intel Corporation +* Copyright 2023-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,6 +49,7 @@ void compute_backward(descriptor_type &desc, sycl::buffer &inout_r sycl::buffer &inout_im) { static_assert(detail::valid_compute_arg::value, "unexpected type for data_type"); + static_assert(!detail::is_complex_arg::value, "expected real type for data_type"); using scalar_type = typename detail::descriptor_info::scalar_type; auto type_corrected_inout_re = inout_re.template reinterpret( @@ -69,6 +70,13 @@ void compute_backward(descriptor_type &desc, sycl::buffer &in, using fwd_type = typename detail::descriptor_info::forward_type; using bwd_type = typename detail::descriptor_info::backward_type; + + // If the DFT is COMPLEX, the input and output types are expected to be complex. + static_assert( + !std::is_same_v || (detail::is_complex_arg::value && + detail::is_complex_arg::value), + "expected std::complex input_type and output_type"); + auto type_corrected_in = in.template reinterpret( detail::reinterpret_range(in.size())); auto type_corrected_out = out.template reinterpret( @@ -85,6 +93,9 @@ void compute_backward(descriptor_type &desc, sycl::buffer &in_re, "unexpected type for input_type"); static_assert(detail::valid_compute_arg::value, "unexpected type for output_type"); + static_assert( + !detail::is_complex_arg::value && !detail::is_complex_arg::value, + "expected input_type and output_type to be real"); using scalar_type = typename detail::descriptor_info::scalar_type; auto type_corrected_in_re = in_re.template reinterpret( @@ -119,6 +130,7 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_ty const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, "unexpected type for data_type"); + static_assert(!detail::is_complex_arg::value, "data_type is expected to be real"); using scalar_type = typename detail::descriptor_info::scalar_type; return get_commit(desc)->backward_ip_rr(desc, reinterpret_cast(inout_re), @@ -137,6 +149,13 @@ sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type using fwd_type = typename detail::descriptor_info::forward_type; using bwd_type = typename detail::descriptor_info::backward_type; + + // If the DFT is COMPLEX, the input and output types are expected to be complex. + static_assert( + !std::is_same_v || (detail::is_complex_arg::value && + detail::is_complex_arg::value), + "expected std::complex input_type and output_type"); + return get_commit(desc)->backward_op_cc(desc, reinterpret_cast(in), reinterpret_cast(out), dependencies); } @@ -150,6 +169,9 @@ sycl::event compute_backward(descriptor_type &desc, input_type *in_re, input_typ "unexpected type for input_type"); static_assert(detail::valid_compute_arg::value, "unexpected type for output_type"); + static_assert( + !detail::is_complex_arg::value && !detail::is_complex_arg::value, + "expected input_type and output_type to be real"); using scalar_type = typename detail::descriptor_info::scalar_type; return get_commit(desc)->backward_op_rr(desc, reinterpret_cast(in_re), diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp index e41a71469..cd9d3df95 100644 --- a/include/oneapi/mkl/dft/detail/types_impl.hpp +++ b/include/oneapi/mkl/dft/detail/types_impl.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -97,6 +97,12 @@ using valid_compute_arg = typename std::bool_constant< (std::is_same_v::scalar_type, double> && is_one_of>::value)>; +template +class is_complex_arg : public std::false_type {}; + +template +class is_complex_arg> : public std::true_type {}; + // compute the range of a reinterpreted buffer template std::size_t reinterpret_range(std::size_t size) { diff --git a/include/oneapi/mkl/dft/forward.hpp b/include/oneapi/mkl/dft/forward.hpp index 3fcd60b1c..50a8d61ea 100644 --- a/include/oneapi/mkl/dft/forward.hpp +++ b/include/oneapi/mkl/dft/forward.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023 Intel Corporation +* Copyright 2023-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,6 +50,7 @@ void compute_forward(descriptor_type &desc, sycl::buffer &inout_re sycl::buffer &inout_im) { static_assert(detail::valid_compute_arg::value, "unexpected type for data_type"); + static_assert(!detail::is_complex_arg::value, "expected real type for data_type"); using scalar_type = typename detail::descriptor_info::scalar_type; auto type_corrected_inout_re = inout_re.template reinterpret( @@ -70,6 +71,13 @@ void compute_forward(descriptor_type &desc, sycl::buffer &in, using fwd_type = typename detail::descriptor_info::forward_type; using bwd_type = typename detail::descriptor_info::backward_type; + + // If the DFT is COMPLEX, the input and output types are expected to be complex. + static_assert( + !std::is_same_v || (detail::is_complex_arg::value && + detail::is_complex_arg::value), + "expected std::complex input_type and output_type"); + auto type_corrected_in = in.template reinterpret( detail::reinterpret_range(in.size())); auto type_corrected_out = out.template reinterpret( @@ -86,6 +94,9 @@ void compute_forward(descriptor_type &desc, sycl::buffer &in_re, "unexpected type for input_type"); static_assert(detail::valid_compute_arg::value, "unexpected type for output_type"); + static_assert( + !detail::is_complex_arg::value && !detail::is_complex_arg::value, + "expected input_type and output_type to be real"); using scalar_type = typename detail::descriptor_info::scalar_type; auto type_corrected_in_re = in_re.template reinterpret( @@ -119,6 +130,7 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_typ const std::vector &dependencies = {}) { static_assert(detail::valid_compute_arg::value, "unexpected type for data_type"); + static_assert(!detail::is_complex_arg::value, "data_type is expected to be real"); using scalar_type = typename detail::descriptor_info::scalar_type; return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast(inout_re), @@ -136,6 +148,13 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type * using fwd_type = typename detail::descriptor_info::forward_type; using bwd_type = typename detail::descriptor_info::backward_type; + + // If the DFT is COMPLEX, the input and output types are expected to be complex. + static_assert( + !std::is_same_v || (detail::is_complex_arg::value && + detail::is_complex_arg::value), + "expected std::complex input_type and output_type"); + return get_commit(desc)->forward_op_cc(desc, reinterpret_cast(in), reinterpret_cast(out), dependencies); } @@ -149,6 +168,9 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in_re, input_type "unexpected type for input_type"); static_assert(detail::valid_compute_arg::value, "unexpected type for output_type"); + static_assert( + !detail::is_complex_arg::value && !detail::is_complex_arg::value, + "expected input_type and output_type to be real"); using scalar_type = typename detail::descriptor_info::scalar_type; return get_commit(desc)->forward_op_rr(desc, reinterpret_cast(in_re),