Skip to content

Commit

Permalink
[rocFFT] Add checks for rocFFT version due to internal bug
Browse files Browse the repository at this point in the history
Due to rocFFt internal bug ROCm/rocFFT#507
Add cmake version checks based upon rocFFT version.
Add exception to deal with faulty cases for affected rocFFT version.
RocFFT versions taken from https://github.com/ROCm/rocFFT/blob/develop/CHANGELOG.md
  • Loading branch information
s-Nick committed Aug 30, 2024
1 parent 092054b commit 54e0cf9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/dft/backends/rocfft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ find_package(HIP REQUIRED)
# Require the minimum rocFFT version matching with ROCm 5.4.3.
find_package(rocfft 1.0.21 REQUIRED)

if (${rocfft_VERSION_MAJOR} EQUAL "1" AND ${rocfft_VERSION_MINOR} EQUAL "0"
AND ((${rocfft_VERSION_PATCH} GREATER "22")
AND (${rocfft_VERSION_PATCH} LESS "31") ))
message(WARNING "Due to a bug in rocFFT some tests fail with the version in\
use. If possible use a version greater of 1.0.30 or less of 1.0.23.
Current rocFFT version ${rocfft_VERSION}")
endif()

target_link_libraries(${LIB_OBJ} PRIVATE hip::host roc::rocfft)

# Allow to compile for different ROCm versions. See the README for the supported
Expand All @@ -62,6 +70,7 @@ find_path(
NO_DEFAULT_PATH
REQUIRED
)

target_include_directories(${LIB_OBJ} PRIVATE ${rocfft_EXTRA_INCLUDE_DIR})

target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL)
Expand Down
23 changes: 23 additions & 0 deletions src/dft/backends/rocfft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "rocfft_handle.hpp"

#include <rocfft.h>
#include <rocfft-version.h>
#include <hip/hip_runtime_api.h>

namespace oneapi::mkl::dft::rocfft {
Expand Down Expand Up @@ -259,6 +260,28 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
std::reverse(stride_vecs.vec_b.begin(), stride_vecs.vec_b.end());
stride_vecs.vec_b.pop_back(); // Offset is not included.

// This workaround is needed due to a confirmed issue in rocFFT from version
// 1.0.23 to 1.0.30. Those rocFFT version correspond to rocm version from
// 5.6.0 to 6.3.0.
// Link to rocFFT issue: https://github.com/ROCm/rocFFT/issues/507
if constexpr (rocfft_version_major == 1 && rocfft_version_minor == 0 &&
(rocfft_version_patch > 22 && rocfft_version_patch < 31)) {
if (dom == dft::domain::COMPLEX && dimensions > 2) {
auto stride_checker = [&](const auto& a, const auto& b) {
for (ulong i = 0; i < dimensions; ++i) {
if (a[i] != b[i])
return false;
}
return true;
};
std::printf("hello\n");
if (!stride_checker(stride_vecs.vec_a, stride_vecs.vec_b))
throw oneapi::mkl::unimplemented(
"DFT", func,
"due to a bug in rocfft version in use, it requires fwd and bwd stride to be the same for COMPLEX out_of_place computations");
}
}

rocfft_plan_description plan_desc_fwd, plan_desc_bwd; // Can't reuse with ROCm 6 due to bug.
if (rocfft_plan_description_create(&plan_desc_fwd) != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
Expand Down

0 comments on commit 54e0cf9

Please sign in to comment.