From 128ef68cd2ed35c6543ca163a8d79bd69b678748 Mon Sep 17 00:00:00 2001 From: Hugh Bird Date: Wed, 7 Aug 2024 12:06:00 +0100 Subject: [PATCH] [DFT] Improvements to FWD/BWD_STRIDE API * Addresses spec deviations for API implementation * Improves run-time checks * Improves comments --- src/dft/backends/cufft/commit.cpp | 2 +- src/dft/backends/descriptor.cpp | 8 +- src/dft/backends/mklcpu/descriptor.cpp | 8 +- src/dft/backends/mklgpu/descriptor.cpp | 8 +- src/dft/backends/rocfft/commit.cpp | 44 +++++---- src/dft/descriptor.cxx | 42 +++++--- .../dft/include/compute_out_of_place.hpp | 97 +++++++++++-------- .../dft/source/descriptor_tests.cpp | 46 +++++---- 8 files changed, 143 insertions(+), 112 deletions(-) diff --git a/src/dft/backends/cufft/commit.cpp b/src/dft/backends/cufft/commit.cpp index faf4332c0..8dd9b225b 100644 --- a/src/dft/backends/cufft/commit.cpp +++ b/src/dft/backends/cufft/commit.cpp @@ -173,7 +173,7 @@ class cufft_commit final : public dft::detail::commit_impl { if (a_min - stride_vecs.vec_a.begin() != b_min - stride_vecs.vec_b.begin()) { throw mkl::unimplemented( "dft/backends/cufft", __FUNCTION__, - "cufft requires that if ordered by stride length, the order of strides is the same for input and output strides!"); + "cufft requires that if ordered by stride length, the order of strides is the same for input/output or fwd/bwd strides!"); } } const int a_stride = static_cast(*a_min); diff --git a/src/dft/backends/descriptor.cpp b/src/dft/backends/descriptor.cpp index aa4cded9c..c6f6884f8 100644 --- a/src/dft/backends/descriptor.cpp +++ b/src/dft/backends/descriptor.cpp @@ -22,9 +22,7 @@ #include "../descriptor.cxx" -namespace oneapi { -namespace mkl { -namespace dft { +namespace oneapi::mkl::dft::detail { template void descriptor::commit(sycl::queue &queue) { @@ -41,6 +39,4 @@ template void descriptor::commit(sycl::queue &) template void descriptor::commit(sycl::queue &); template void descriptor::commit(sycl::queue &); -} //namespace dft -} //namespace mkl -} //namespace oneapi +} //namespace oneapi::mkl::dft::detail diff --git a/src/dft/backends/mklcpu/descriptor.cpp b/src/dft/backends/mklcpu/descriptor.cpp index 2bb0e2835..a72fdcfc3 100644 --- a/src/dft/backends/mklcpu/descriptor.cpp +++ b/src/dft/backends/mklcpu/descriptor.cpp @@ -22,9 +22,7 @@ #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" -namespace oneapi { -namespace mkl { -namespace dft { +namespace oneapi::mkl::dft::detail { template void descriptor::commit(backend_selector selector) { @@ -46,6 +44,4 @@ template void descriptor::commit( template void descriptor::commit( backend_selector); -} //namespace dft -} //namespace mkl -} //namespace oneapi +} //namespace oneapi::mkl::dft::detail diff --git a/src/dft/backends/mklgpu/descriptor.cpp b/src/dft/backends/mklgpu/descriptor.cpp index d2d2fee7a..7f7f0bf70 100644 --- a/src/dft/backends/mklgpu/descriptor.cpp +++ b/src/dft/backends/mklgpu/descriptor.cpp @@ -22,9 +22,7 @@ #include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" -namespace oneapi { -namespace mkl { -namespace dft { +namespace oneapi::mkl::dft::detail { template void descriptor::commit(backend_selector selector) { @@ -46,6 +44,4 @@ template void descriptor::commit( template void descriptor::commit( backend_selector); -} //namespace dft -} //namespace mkl -} //namespace oneapi +} //namespace oneapi::mkl::dft::detail diff --git a/src/dft/backends/rocfft/commit.cpp b/src/dft/backends/rocfft/commit.cpp index db5a7f965..ea0ba18a6 100644 --- a/src/dft/backends/rocfft/commit.cpp +++ b/src/dft/backends/rocfft/commit.cpp @@ -289,26 +289,34 @@ class rocfft_commit final : public dft::detail::commit_impl { if (dom == dft::domain::REAL) { lengths_cplx[0] = lengths_cplx[0] / 2 + 1; } - // When creating real-complex descriptions, the strides will always be wrong for one of the directions. + // When creating real-complex descriptions with INPUT/OUTPUT_STRIDES, + // the strides will always be wrong for one of the directions. // This is because the least significant dimension is symmetric. // If the strides are invalid (too small to fit) then just don't bother creating the plan. - const bool vec_a_valid_as_reals = - dimensions == 1 || - (lengths_cplx[stride_a_indices[0]] <= stride_vecs.vec_a[stride_a_indices[1]] && - (dimensions == 2 || - lengths_cplx[stride_a_indices[0]] * lengths_cplx[stride_a_indices[1]] <= - stride_vecs.vec_a[stride_a_indices[2]])); - const bool vec_b_valid_as_reals = - dimensions == 1 || - (lengths_cplx[stride_b_indices[0]] <= stride_vecs.vec_b[stride_b_indices[1]] && - (dimensions == 2 || - lengths_cplx[stride_b_indices[0]] * lengths_cplx[stride_b_indices[1]] <= - stride_vecs.vec_b[stride_b_indices[2]])); - // Test if the stride vector being used as the fwd domain for each direction has valid strides for that use. - bool valid_forward = - stride_vecs.fwd_in == stride_vecs.vec_a && vec_a_valid_as_reals || vec_b_valid_as_reals; - bool valid_backward = stride_vecs.bwd_out == stride_vecs.vec_a && vec_a_valid_as_reals || - vec_b_valid_as_reals; + auto are_strides_smaller_than_lengths = [=](auto& svec, auto& sindices, + auto& domain_lengths) { + return dimensions == 1 || + (domain_lengths[sindices[0]] <= svec[sindices[1]] && + (dimensions == 2 || domain_lengths[sindices[0]] * domain_lengths[sindices[1]] <= + svec[sindices[2]])); + }; + + const bool vec_a_valid_as_fwd_domain = + are_strides_smaller_than_lengths(stride_vecs.vec_a, stride_a_indices, lengths); + const bool vec_b_valid_as_fwd_domain = + are_strides_smaller_than_lengths(stride_vecs.vec_b, stride_b_indices, lengths); + const bool vec_a_valid_as_bwd_domain = + are_strides_smaller_than_lengths(stride_vecs.vec_a, stride_a_indices, lengths_cplx); + const bool vec_b_valid_as_bwd_domain = + are_strides_smaller_than_lengths(stride_vecs.vec_b, stride_b_indices, lengths_cplx); + + // Test if the stride vector being used as the fwd/bwd domain for each direction has valid strides for that use. + bool valid_forward = (stride_vecs.fwd_in == stride_vecs.vec_a && + vec_a_valid_as_fwd_domain && vec_b_valid_as_bwd_domain) || + (vec_b_valid_as_fwd_domain && vec_a_valid_as_bwd_domain); + bool valid_backward = (stride_vecs.bwd_in == stride_vecs.vec_a && + vec_a_valid_as_bwd_domain && vec_b_valid_as_fwd_domain) || + (vec_b_valid_as_bwd_domain && vec_a_valid_as_fwd_domain); if (!valid_forward && !valid_backward) { throw mkl::exception("dft/backends/cufft", __FUNCTION__, "Invalid strides."); diff --git a/src/dft/descriptor.cxx b/src/dft/descriptor.cxx index a9acd3b9e..aa0d9d70e 100644 --- a/src/dft/descriptor.cxx +++ b/src/dft/descriptor.cxx @@ -28,21 +28,28 @@ namespace mkl { namespace dft { namespace detail { -// Compute the default strides. Modifies real_strides and complex_strides arguments. +// Compute the default strides. Modifies real_strides and complex_strides arguments +template inline void compute_default_strides(const std::vector& dimensions, std::vector& fwd_strides, std::vector& bwd_strides) { - auto rank = dimensions.size(); - std::vector strides(rank + 1, 1); - for (auto i = rank - 1; i > 0; --i) { - strides[i] = strides[i + 1] * dimensions[i]; + const auto rank = dimensions.size(); + fwd_strides = std::vector(rank + 1, 1); + fwd_strides[0] = 0; + bwd_strides = fwd_strides; + if (rank == 1) { + return; + } + + bwd_strides[rank - 1] = + dom == domain::COMPLEX ? dimensions[rank - 1] : (dimensions[rank - 1] / 2) + 1; + fwd_strides[rank - 1] = + dom == domain::COMPLEX ? dimensions[rank - 1] : 2 * bwd_strides[rank - 1]; + for (auto i = rank - 1; i > 1; --i) { + // Can't start at rank - 2 with unsigned type and minimum value of rank being 1. + bwd_strides[i - 1] = bwd_strides[i] * dimensions[i - 1]; + fwd_strides[i - 1] = fwd_strides[i] * dimensions[i - 1]; } - strides[0] = 0; - // Fwd/Bwd strides and Input/Output strides being the same by default means - // that we don't have to specify if we default to using fwd/bwd strides or - // input/output strides. - bwd_strides = strides; - fwd_strides = std::move(strides); } template @@ -69,12 +76,15 @@ void descriptor::set_value(config_param param, ...) { case config_param::PRECISION: throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter."); break; +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" case config_param::INPUT_STRIDES: detail::set_value(values_, va_arg(vl, std::int64_t*)); break; case config_param::OUTPUT_STRIDES: detail::set_value(values_, va_arg(vl, std::int64_t*)); break; +#pragma clang diagnostic pop case config_param::FWD_STRIDES: detail::set_value(values_, va_arg(vl, std::int64_t*)); break; @@ -150,10 +160,9 @@ descriptor::descriptor(std::vector dimensions) { "Invalid dimension value (negative or 0)."); } } - compute_default_strides(dimensions, values_.fwd_strides, values_.bwd_strides); - // Assume forward transform. - values_.input_strides = values_.fwd_strides; - values_.output_strides = values_.bwd_strides; + compute_default_strides(dimensions, values_.fwd_strides, values_.bwd_strides); + values_.input_strides = std::vector(dimensions.size() + 1, 0); + values_.output_strides = std::vector(dimensions.size() + 1, 0); values_.bwd_scale = real_t(1.0); values_.fwd_scale = real_t(1.0); values_.number_of_transforms = 1; @@ -220,6 +229,8 @@ void descriptor::get_value(config_param param, ...) const { *va_arg(vl, config_value*) = values_.conj_even_storage; break; case config_param::PLACEMENT: *va_arg(vl, config_value*) = values_.placement; break; +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" case config_param::INPUT_STRIDES: std::copy(values_.input_strides.begin(), values_.input_strides.end(), va_arg(vl, std::int64_t*)); @@ -228,6 +239,7 @@ void descriptor::get_value(config_param param, ...) const { std::copy(values_.output_strides.begin(), values_.output_strides.end(), va_arg(vl, std::int64_t*)); break; +#pragma clang diagnostic pop case config_param::FWD_STRIDES: std::copy(values_.fwd_strides.begin(), values_.fwd_strides.end(), va_arg(vl, std::int64_t*)); diff --git a/tests/unit_tests/dft/include/compute_out_of_place.hpp b/tests/unit_tests/dft/include/compute_out_of_place.hpp index df5e1e323..bcfd09dda 100644 --- a/tests/unit_tests/dft/include/compute_out_of_place.hpp +++ b/tests/unit_tests/dft/include/compute_out_of_place.hpp @@ -29,11 +29,26 @@ int DFT_Test::test_out_of_place_buffer() { return test_skipped; } + descriptor_t descriptor{ sizes }; + auto strides_fwd_cpy = strides_fwd; + auto strides_bwd_cpy = strides_bwd; + if (strides_fwd_cpy.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data()); + } + else { + strides_fwd_cpy.resize(sizes.size() + 1); + descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data()); + } + if (strides_bwd_cpy.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data()); + } + else { + strides_bwd_cpy.resize(sizes.size() + 1); + descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data()); + } auto [forward_distance, backward_distance] = - get_default_distances(sizes, strides_fwd, strides_bwd); + get_default_distances(sizes, strides_fwd_cpy, strides_bwd_cpy); auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); - - descriptor_t descriptor{ sizes }; descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); if constexpr (domain == oneapi::mkl::dft::domain::REAL) { @@ -45,22 +60,12 @@ int DFT_Test::test_out_of_place_buffer() { descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance); descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance); - if (strides_fwd.size()) { - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd.data()); - } - if (strides_bwd.size()) { - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd.data()); - } - else if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - const auto complex_strides = get_conjugate_even_complex_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, complex_strides.data()); - } commit_descriptor(descriptor, sycl_queue); std::vector fwd_data( - strided_copy(input, sizes, strides_fwd, batches, forward_distance)); + strided_copy(input, sizes, strides_fwd_cpy, batches, forward_distance)); auto tmp = std::vector( - cast_unsigned(backward_distance * batches + get_default(strides_bwd, 0, 0L)), 0); + cast_unsigned(backward_distance * batches + get_default(strides_bwd_cpy, 0, 0L)), 0); { sycl::buffer fwd_buf{ fwd_data }; sycl::buffer bwd_buf{ tmp }; @@ -74,7 +79,7 @@ int DFT_Test::test_out_of_place_buffer() { for (std::int64_t i = 0; i < batches; i++) { EXPECT_TRUE(check_equal_strided( bwd_ptr + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes, - strides_bwd, abs_error_margin, rel_error_margin, std::cout)); + strides_bwd_cpy, abs_error_margin, rel_error_margin, std::cout)); } } @@ -88,9 +93,9 @@ int DFT_Test::test_out_of_place_buffer() { [this](auto &x) { x *= static_cast(forward_elements); }); for (std::int64_t i = 0; i < batches; i++) { - EXPECT_TRUE(check_equal_strided(fwd_data.data() + forward_distance * i, - input.data() + ref_distance * i, sizes, strides_fwd, - abs_error_margin, rel_error_margin, std::cout)); + EXPECT_TRUE(check_equal_strided( + fwd_data.data() + forward_distance * i, input.data() + ref_distance * i, sizes, + strides_fwd_cpy, abs_error_margin, rel_error_margin, std::cout)); } return !::testing::Test::HasFailure(); @@ -103,10 +108,34 @@ int DFT_Test::test_out_of_place_USM() { } const std::vector no_dependencies; - auto [forward_distance, backward_distance] = - get_default_distances(sizes, strides_fwd, strides_bwd); - descriptor_t descriptor{ sizes }; + auto strides_fwd_cpy = strides_fwd; + auto strides_bwd_cpy = strides_bwd; + if (strides_fwd_cpy.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data()); + } + else { + strides_fwd_cpy.resize(sizes.size() + 1); + descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data()); + } + if (strides_bwd_cpy.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data()); + } + else { + strides_bwd_cpy.resize(sizes.size() + 1); + descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data()); + } + auto [forward_distance, backward_distance] = + get_default_distances(sizes, strides_fwd_cpy, strides_bwd_cpy); + auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); + descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, + oneapi::mkl::dft::config_value::NOT_INPLACE); + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + descriptor.set_value(oneapi::mkl::dft::config_param::CONJUGATE_EVEN_STORAGE, + oneapi::mkl::dft::config_value::COMPLEX_COMPLEX); + descriptor.set_value(oneapi::mkl::dft::config_param::PACKED_FORMAT, + oneapi::mkl::dft::config_value::CCE_FORMAT); + } descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); if constexpr (domain == oneapi::mkl::dft::domain::REAL) { @@ -118,36 +147,26 @@ int DFT_Test::test_out_of_place_USM() { descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance); descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance); - if (strides_fwd.size()) { - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd.data()); - } - if (strides_bwd.size()) { - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd.data()); - } - else if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - const auto complex_strides = get_conjugate_even_complex_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, complex_strides.data()); - } commit_descriptor(descriptor, sycl_queue); auto ua_input = usm_allocator_t(cxt, *dev); auto ua_output = usm_allocator_t(cxt, *dev); std::vector fwd( - strided_copy(input, sizes, strides_fwd, batches, forward_distance, ua_input), ua_input); + strided_copy(input, sizes, strides_fwd_cpy, batches, forward_distance, ua_input), ua_input); std::vector bwd( - cast_unsigned(backward_distance * batches + get_default(strides_bwd, 0, 0L)), ua_output); + cast_unsigned(backward_distance * batches + get_default(strides_bwd_cpy, 0, 0L)), + ua_output); oneapi::mkl::dft::compute_forward( descriptor, fwd.data(), bwd.data(), no_dependencies) .wait_and_throw(); auto bwd_ptr = &bwd[0]; - auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); for (std::int64_t i = 0; i < batches; i++) { EXPECT_TRUE(check_equal_strided( bwd_ptr + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes, - strides_bwd, abs_error_margin, rel_error_margin, std::cout)); + strides_bwd_cpy, abs_error_margin, rel_error_margin, std::cout)); } oneapi::mkl::dft::compute_backward, FwdOutputType, @@ -160,9 +179,9 @@ int DFT_Test::test_out_of_place_USM() { [this](auto &x) { x *= static_cast(forward_elements); }); for (std::int64_t i = 0; i < batches; i++) { - EXPECT_TRUE(check_equal_strided(fwd.data() + forward_distance * i, - input.data() + ref_distance * i, sizes, strides_fwd, - abs_error_margin, rel_error_margin, std::cout)); + EXPECT_TRUE(check_equal_strided( + fwd.data() + forward_distance * i, input.data() + ref_distance * i, sizes, + strides_fwd_cpy, abs_error_margin, rel_error_margin, std::cout)); } return !::testing::Test::HasFailure(); diff --git a/tests/unit_tests/dft/source/descriptor_tests.cpp b/tests/unit_tests/dft/source/descriptor_tests.cpp index a420eb1e2..155160149 100644 --- a/tests/unit_tests/dft/source/descriptor_tests.cpp +++ b/tests/unit_tests/dft/source/descriptor_tests.cpp @@ -140,7 +140,7 @@ static void set_and_get_io_strides() { descriptor.get_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, input_strides_before_set.data()); - EXPECT_EQ(default_strides_value, input_strides_before_set); + EXPECT_EQ(std::vector(strides_size, 0), input_strides_before_set); descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, input_strides_value.data()); descriptor.get_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, input_strides_after_set.data()); @@ -154,7 +154,7 @@ static void set_and_get_io_strides() { std::vector output_strides_after_set(strides_size); descriptor.get_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, output_strides_before_set.data()); - EXPECT_EQ(default_strides_value, output_strides_before_set); + EXPECT_EQ(std::vector(strides_size, 0), output_strides_before_set); descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, output_strides_value.data()); descriptor.get_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, @@ -176,23 +176,27 @@ static void set_and_get_fwd_bwd_strides() { const std::int64_t default_stride_d2 = default_3d_lengths[2]; const std::int64_t default_stride_d3 = 1; - std::vector default_strides_value{ 0, default_stride_d1, default_stride_d2, - default_stride_d3 }; - - std::vector fwd_strides_value; - std::vector bwd_strides_value; + std::vector fwd_strides_default_value; + std::vector bwd_strides_default_value; if constexpr (domain == oneapi::mkl::dft::domain::COMPLEX) { - fwd_strides_value = { 50, default_stride_d1 * 2, default_stride_d2 * 2, - default_stride_d3 * 2 }; - bwd_strides_value = { 50, default_stride_d1 * 2, default_stride_d2 * 2, - default_stride_d3 * 2 }; + fwd_strides_default_value = { 0, default_stride_d1, default_stride_d2, default_stride_d3 }; + bwd_strides_default_value = { 0, default_stride_d1, default_stride_d2, default_stride_d3 }; } else { - fwd_strides_value = { 0, default_3d_lengths[1] * (default_3d_lengths[2] / 2 + 1) * 2, - (default_3d_lengths[2] / 2 + 1) * 2, 1 }; - bwd_strides_value = { 0, default_3d_lengths[1] * (default_3d_lengths[2] / 2 + 1), - (default_3d_lengths[2] / 2 + 1), 1 }; + fwd_strides_default_value = { 0, + default_3d_lengths[1] * (default_3d_lengths[2] / 2 + 1) * 2, + (default_3d_lengths[2] / 2 + 1) * 2, 1 }; + bwd_strides_default_value = { 0, default_3d_lengths[1] * (default_3d_lengths[2] / 2 + 1), + (default_3d_lengths[2] / 2 + 1), 1 }; + } + auto fwd_strides_new_value = fwd_strides_default_value; + auto bwd_strides_new_value = bwd_strides_default_value; + for (auto i = 0UL; i < fwd_strides_new_value.size(); ++i) { + fwd_strides_new_value[i] *= 4; + bwd_strides_new_value[i] *= 4; } + fwd_strides_new_value[0] = 50; + bwd_strides_new_value[0] = 50; std::vector fwd_strides_before_set(strides_size); std::vector fwd_strides_after_set(strides_size); @@ -201,14 +205,14 @@ static void set_and_get_fwd_bwd_strides() { descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides_before_set.data()); - EXPECT_EQ(default_strides_value, fwd_strides_before_set); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides_value.data()); + EXPECT_EQ(fwd_strides_default_value, fwd_strides_before_set); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides_new_value.data()); descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides_after_set.data()); descriptor.get_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, input_strides_after_set.data()); descriptor.get_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, output_strides_after_set.data()); - EXPECT_EQ(fwd_strides_value, fwd_strides_after_set); + EXPECT_EQ(fwd_strides_new_value, fwd_strides_after_set); EXPECT_EQ(std::vector(strides_size, 0), input_strides_after_set); EXPECT_EQ(std::vector(strides_size, 0), output_strides_after_set); @@ -216,10 +220,10 @@ static void set_and_get_fwd_bwd_strides() { std::vector bwd_strides_after_set(strides_size); descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides_before_set.data()); - EXPECT_EQ(default_strides_value, bwd_strides_before_set); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides_value.data()); + EXPECT_EQ(bwd_strides_default_value, bwd_strides_before_set); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides_new_value.data()); descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides_after_set.data()); - EXPECT_EQ(bwd_strides_value, bwd_strides_after_set); + EXPECT_EQ(bwd_strides_new_value, bwd_strides_after_set); } #pragma clang diagnostic pop