Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DFT][MKLGPU] Use FWD/BWD_STRIDES #514

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dft/backends/cufft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ class cufft_commit final : public dft::detail::commit_impl<prec, dom> {

std::int64_t get_plan_workspace_size_bytes(cufftHandle handle) {
std::size_t size = 0;
cufftGetSize(*plans[0], &size);
cufftGetSize(handle, &size);
std::int64_t padded_size = static_cast<int64_t>(size);
return padded_size;
}
Expand Down
5 changes: 4 additions & 1 deletion src/dft/backends/mklgpu/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ namespace detail {
template <dft::detail::precision prec, dft::detail::domain dom, typename... ArgTs>
inline auto compute_backward(dft::detail::descriptor<prec, dom> &desc, ArgTs &&... args) {
using mklgpu_desc_t = dft::descriptor<to_mklgpu(prec), to_mklgpu(dom)>;
using desc_shptr_t = std::shared_ptr<mklgpu_desc_t>;
using handle_t = std::pair<desc_shptr_t, desc_shptr_t>;
auto commit_handle = dft::detail::get_commit(desc);
if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklgpu) {
throw mkl::invalid_argument("DFT", "compute_backward",
"DFT descriptor has not been commited for MKLGPU");
}
auto mklgpu_desc = reinterpret_cast<mklgpu_desc_t *>(commit_handle->get_handle());
auto handle = reinterpret_cast<handle_t *>(commit_handle->get_handle());
auto mklgpu_desc = handle->second; // Second because backward DFT.
int commit_status{ DFTI_UNCOMMITTED };
mklgpu_desc->get_value(dft::config_param::COMMIT_STATUS, &commit_status);
if (commit_status != DFTI_COMMITTED) {
Expand Down
86 changes: 72 additions & 14 deletions src/dft/backends/mklgpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
// MKLGPU header
#include "oneapi/mkl/dfti.hpp"

// MKL 2024.1 deprecates input/output strides.
#include "mkl_version.h"
#if INTEL_MKL_VERSION < 20240001
#error MKLGPU requires oneMKL 2024.1 or later
#endif

/**
Note that in this file, the Intel oneMKL closed-source library's interface mirrors the interface
of this OneMKL open-source library. Consequently, the types under dft::TYPE are closed-source oneMKL types,
Expand All @@ -53,14 +59,22 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
// Equivalent MKLGPU precision and domain from OneMKL's precision / domain.
static constexpr dft::precision mklgpu_prec = to_mklgpu(prec);
static constexpr dft::domain mklgpu_dom = to_mklgpu(dom);

// A pair of descriptors are needed because of the [[deprecated]]IN/OUTPUT_STRIDES vs F/BWD_STRIDES API.
// Of the pair [0] is fwd DFT, [1] is backward DFT. If possible, the pointers refer to the same desciptor.
// Both pointers must be valid.
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
using mklgpu_descriptor_t = dft::descriptor<mklgpu_prec, mklgpu_dom>;
using descriptor_shptr_t = std::shared_ptr<mklgpu_descriptor_t>;
using handle_t = std::pair<descriptor_shptr_t, descriptor_shptr_t>;

using scalar_type = typename dft::detail::commit_impl<prec, dom>::scalar_type;

public:
mklgpu_commit(sycl::queue queue, const dft::detail::dft_values<prec, dom>& config_values)
: oneapi::mkl::dft::detail::commit_impl<prec, dom>(queue, backend::mklgpu,
config_values),
handle(config_values.dimensions) {
handle(std::make_shared<mklgpu_descriptor_t>(config_values.dimensions), nullptr) {
handle.second = handle.first; // Make sure the bwd pointer is valid.
// MKLGPU does not throw an informative exception for the following:
if constexpr (prec == dft::detail::precision::DOUBLE) {
if (!queue.get_device().has(sycl::aspect::fp64)) {
Expand All @@ -75,13 +89,43 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
oneapi::mkl::dft::detail::external_workspace_helper<prec, dom>(
config_values.workspace_placement ==
oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL);
set_value(handle, config_values);

// A separate descriptor for each direction may not be required.
bool one_descriptor = config_values.input_strides == config_values.output_strides;
bool forward_good = true;
// Make sure that second is always pointing to something new if this is a recommit.
handle.second = handle.first;

// Generate forward DFT descriptor.
set_value(*handle.first, config_values, true);
try {
handle.commit(this->get_queue());
handle.first->commit(this->get_queue());
}
catch (const std::exception& mkl_exception) {
// Catching the real Intel oneMKL exception causes headaches with naming.
throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what());
// Catching the real Intel oneMKL exception causes headaches with naming
forward_good = false;
if (one_descriptor) {
throw mkl::exception("dft/backends/mklgpu"
"commit",
mkl_exception.what());
}
}

// Generate backward DFT descriptor only if required.
if (!one_descriptor) {
handle.second = std::make_shared<mklgpu_descriptor_t>(config_values.dimensions);
set_value(*handle.second, config_values, false);
try {
handle.second->commit(this->get_queue());
}
catch (const std::exception& mkl_exception) {
// Catching the real Intel oneMKL exception causes headaches with naming.
if (!forward_good) {
throw mkl::exception("dft/backends/mklgpu"
"commit",
mkl_exception.what());
}
}
}
}

Expand All @@ -93,12 +137,18 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {

virtual void set_workspace(scalar_type* usm_workspace) override {
this->external_workspace_helper_.set_workspace_throw(*this, usm_workspace);
handle.set_workspace(usm_workspace);
handle.first->set_workspace(usm_workspace);
if (handle.first != handle.second) {
handle.second->set_workspace(usm_workspace);
}
}

virtual void set_workspace(sycl::buffer<scalar_type>& buffer_workspace) override {
this->external_workspace_helper_.set_workspace_throw(*this, buffer_workspace);
handle.set_workspace(buffer_workspace);
handle.first->set_workspace(buffer_workspace);
if (handle.first != handle.second) {
handle.second->set_workspace(buffer_workspace);
}
}

#define BACKEND mklgpu
Expand All @@ -107,9 +157,10 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {

private:
// The native MKLGPU class.
mklgpu_descriptor_t handle;
handle_t handle;

void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values<prec, dom>& config) {
void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values<prec, dom>& config,
bool assume_fwd_dft) {
using onemkl_param = dft::detail::config_param;
using backend_param = dft::config_param;

Expand All @@ -134,8 +185,14 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data());
desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data());
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
}
else {
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
}
desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist);
desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);
if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) {
Expand All @@ -158,9 +215,10 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {

// This is called by the workspace_helper, and is not part of the user API.
virtual std::int64_t get_workspace_external_bytes_impl() override {
std::size_t workspaceSize = 0;
handle.get_value(dft::config_param::WORKSPACE_BYTES, &workspaceSize);
return static_cast<std::int64_t>(workspaceSize);
std::size_t workspaceSizeFwd = 0, workspaceSizeBwd = 0;
handle.first->get_value(dft::config_param::WORKSPACE_BYTES, &workspaceSizeFwd);
handle.second->get_value(dft::config_param::WORKSPACE_BYTES, &workspaceSizeBwd);
return static_cast<std::int64_t>(std::max(workspaceSizeFwd, workspaceSizeFwd));
}
};
} // namespace detail
Expand Down
5 changes: 4 additions & 1 deletion src/dft/backends/mklgpu/forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ namespace detail {
template <dft::detail::precision prec, dft::detail::domain dom, typename... ArgTs>
inline auto compute_forward(dft::detail::descriptor<prec, dom> &desc, ArgTs &&... args) {
using mklgpu_desc_t = dft::descriptor<to_mklgpu(prec), to_mklgpu(dom)>;
using desc_shptr_t = std::shared_ptr<mklgpu_desc_t>;
using handle_t = std::pair<desc_shptr_t, desc_shptr_t>;
auto commit_handle = dft::detail::get_commit(desc);
if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklgpu) {
throw mkl::invalid_argument("DFT", "compute_forward",
"DFT descriptor has not been commited for MKLGPU");
}
auto mklgpu_desc = reinterpret_cast<mklgpu_desc_t *>(commit_handle->get_handle());
auto handle = reinterpret_cast<handle_t *>(commit_handle->get_handle());
auto mklgpu_desc = handle->first; // First because forward DFT.
int commit_status{ DFTI_UNCOMMITTED };
mklgpu_desc->get_value(dft::config_param::COMMIT_STATUS, &commit_status);
if (commit_status != DFTI_COMMITTED) {
Expand Down