From 6b76b1ddc9afeb4347f66ef1d650254ab9ecda7d Mon Sep 17 00:00:00 2001 From: Hugh Bird Date: Thu, 13 Jun 2024 17:43:39 +0100 Subject: [PATCH] Fix issues in original commit --- src/dft/backends/mklgpu/commit.cpp | 41 ++++++++++++++++-------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/src/dft/backends/mklgpu/commit.cpp b/src/dft/backends/mklgpu/commit.cpp index 7decf4bfb..a7343ec4e 100644 --- a/src/dft/backends/mklgpu/commit.cpp +++ b/src/dft/backends/mklgpu/commit.cpp @@ -39,9 +39,7 @@ // MKL 2024.1 deprecates input/output strides. #include "mkl_version.h" -namespace oneapi::mkl::dft::mklgpu::detail { -constexpr bool mklgpu_use_forward_backward_strides_api = INTEL_MKL_VERSION >= 20240001; -} +#define MKLGPU_USE_FORWARD_BACKWARD_STRIDES_API INTEL_MKL_VERSION >= 20240001 /** Note that in this file, the Intel oneMKL closed-source library's interface mirrors the interface @@ -91,18 +89,22 @@ class mklgpu_commit final : public dft::detail::commit_impl { oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL); // Generate forward DFT descriptor. set_value(*handle.first, config_values, true); + std::optional fwd_exception; try { handle.first->commit(this->get_queue()); } catch (const std::exception& mkl_exception) { // Catching the real Intel oneMKL exception causes headaches with naming. - throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what()); + fwd_exception = mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what()); } // Generate backward DFT descriptor only if required. if (config_values.input_strides == config_values.output_strides) { // Required if second != first before a recommit. handle.second = handle.first; + if (fwd_exception) { + throw *fwd_exception; + } } else { handle.second = std::make_shared(config_values.dimensions); @@ -112,7 +114,9 @@ class mklgpu_commit final : public dft::detail::commit_impl { } catch (const std::exception& mkl_exception) { // Catching the real Intel oneMKL exception causes headaches with naming. - throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what()); + if (fwd_exception) { + throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what()); + } } } } @@ -148,7 +152,7 @@ class mklgpu_commit final : public dft::detail::commit_impl { handle_t handle; void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values& config, - bool assume_fwd_dft) { + [[maybe_unused]] bool assume_fwd_dft) { using onemkl_param = dft::detail::config_param; using backend_param = dft::config_param; @@ -173,22 +177,21 @@ class mklgpu_commit final : public dft::detail::commit_impl { throw mkl::unimplemented("dft/backends/mklgpu", "commit", "MKLGPU does not support nonzero offsets."); } - if constexpr (mklgpu_use_forward_backward_strides_api) { - // Support for Intel oneMKL 2024.1 or newer using FWD/BWD stride API. - if (assume_fwd_dft) { - desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data()); - desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data()); - } - else { - desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data()); - desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data()); - } +#if MKLGPU_USE_FORWARD_BACKWARD_STRIDES_API + // Support for Intel oneMKL 2024.1 or newer using FWD/BWD stride API. + if (assume_fwd_dft) { + desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data()); + desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data()); } else { - // Support for Intel oneMKL older than 2024.1 - desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data()); - desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data()); + desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data()); + desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data()); } +#else // !MKLGPU_USE_FORWARD_BACKWARD_STRIDES_API + // Support for Intel oneMKL older than 2024.1 + desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data()); + desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data()); +#endif desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist); desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist); if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) {