Skip to content

Commit

Permalink
Fix issues in original commit
Browse files Browse the repository at this point in the history
  • Loading branch information
hjabird committed Jun 13, 2024
1 parent 8220922 commit 6b76b1d
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions src/dft/backends/mklgpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@

// Intel oneMKL 2024.1 deprecates the INPUT_STRIDES/OUTPUT_STRIDES configuration parameters in favour of FWD_STRIDES/BWD_STRIDES.
#include "mkl_version.h"
namespace oneapi::mkl::dft::mklgpu::detail {
// Compile-time flag: true when INTEL_MKL_VERSION (from mkl_version.h) is at least 20240001.
// Used to select the FWD_STRIDES/BWD_STRIDES API over INPUT_STRIDES/OUTPUT_STRIDES.
constexpr bool mklgpu_use_forward_backward_strides_api = INTEL_MKL_VERSION >= 20240001;
} // namespace oneapi::mkl::dft::mklgpu::detail
// Non-zero when INTEL_MKL_VERSION (from mkl_version.h) is at least 20240001, i.e. when the
// FWD_STRIDES/BWD_STRIDES API should be used instead of INPUT_STRIDES/OUTPUT_STRIDES.
// The expansion is parenthesized so the comparison keeps its meaning when the macro is
// combined with other operators (e.g. `#if !MKLGPU_USE_FORWARD_BACKWARD_STRIDES_API`).
#define MKLGPU_USE_FORWARD_BACKWARD_STRIDES_API (INTEL_MKL_VERSION >= 20240001)

/**
Note that in this file, the Intel oneMKL closed-source library's interface mirrors the interface
Expand Down Expand Up @@ -91,18 +89,22 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL);
// Generate forward DFT descriptor.
set_value(*handle.first, config_values, true);
std::optional<mkl::exception> fwd_exception;
try {
handle.first->commit(this->get_queue());
}
catch (const std::exception& mkl_exception) {
// Catching the real Intel oneMKL exception causes headaches with naming.
throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what());
fwd_exception = mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what());
}

// Generate backward DFT descriptor only if required.
if (config_values.input_strides == config_values.output_strides) {
// Required if second != first before a recommit.
handle.second = handle.first;
if (fwd_exception) {
throw *fwd_exception;
}
}
else {
handle.second = std::make_shared<mklgpu_descriptor_t>(config_values.dimensions);
Expand All @@ -112,7 +114,9 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
}
catch (const std::exception& mkl_exception) {
// Catching the real Intel oneMKL exception causes headaches with naming.
throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what());
if (fwd_exception) {
throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what());
}
}
}
}
Expand Down Expand Up @@ -148,7 +152,7 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
handle_t handle;

void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values<prec, dom>& config,
bool assume_fwd_dft) {
[[maybe_unused]] bool assume_fwd_dft) {
using onemkl_param = dft::detail::config_param;
using backend_param = dft::config_param;

Expand All @@ -173,22 +177,21 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
if constexpr (mklgpu_use_forward_backward_strides_api) {
// Support for Intel oneMKL 2024.1 or newer using FWD/BWD stride API.
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
}
else {
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
}
#if MKLGPU_USE_FORWARD_BACKWARD_STRIDES_API
// Support for Intel oneMKL 2024.1 or newer using FWD/BWD stride API.
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
}
else {
// Support for Intel oneMKL older than 2024.1
desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data());
desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data());
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
}
#else // !MKLGPU_USE_FORWARD_BACKWARD_STRIDES_API
// Support for Intel oneMKL older than 2024.1
desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data());
desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data());
#endif
desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist);
desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);
if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) {
Expand Down

0 comments on commit 6b76b1d

Please sign in to comment.