Skip to content

Commit

Permalink
[DFT] Improvements to FWD/BWD_STRIDE API
Browse files Browse the repository at this point in the history
* Addresses spec deviations for API implementation
* Improves run-time checks
* Improves comments
  • Loading branch information
hjabird committed Aug 7, 2024
1 parent 5b5f3f4 commit 128ef68
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 112 deletions.
2 changes: 1 addition & 1 deletion src/dft/backends/cufft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class cufft_commit final : public dft::detail::commit_impl<prec, dom> {
if (a_min - stride_vecs.vec_a.begin() != b_min - stride_vecs.vec_b.begin()) {
throw mkl::unimplemented(
"dft/backends/cufft", __FUNCTION__,
"cufft requires that if ordered by stride length, the order of strides is the same for input and output strides!");
"cufft requires that if ordered by stride length, the order of strides is the same for input/output or fwd/bwd strides!");
}
}
const int a_stride = static_cast<int>(*a_min);
Expand Down
8 changes: 2 additions & 6 deletions src/dft/backends/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "../descriptor.cxx"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(sycl::queue &queue) {
Expand All @@ -41,6 +39,4 @@ template void descriptor<precision::SINGLE, domain::REAL>::commit(sycl::queue &)
template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(sycl::queue &);
template void descriptor<precision::DOUBLE, domain::REAL>::commit(sycl::queue &);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail
8 changes: 2 additions & 6 deletions src/dft/backends/mklcpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklcpu> selector) {
Expand All @@ -46,6 +44,4 @@ template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(
template void descriptor<precision::DOUBLE, domain::REAL>::commit(
backend_selector<backend::mklcpu>);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail
8 changes: 2 additions & 6 deletions src/dft/backends/mklgpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklgpu> selector) {
Expand All @@ -46,6 +44,4 @@ template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(
template void descriptor<precision::DOUBLE, domain::REAL>::commit(
backend_selector<backend::mklgpu>);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail
44 changes: 26 additions & 18 deletions src/dft/backends/rocfft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,26 +289,34 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
if (dom == dft::domain::REAL) {
lengths_cplx[0] = lengths_cplx[0] / 2 + 1;
}
// When creating real-complex descriptions, the strides will always be wrong for one of the directions.
// When creating real-complex descriptions with INPUT/OUTPUT_STRIDES,
// the strides will always be wrong for one of the directions.
// This is because the least significant dimension is symmetric.
// If the strides are invalid (too small to fit) then just don't bother creating the plan.
const bool vec_a_valid_as_reals =
dimensions == 1 ||
(lengths_cplx[stride_a_indices[0]] <= stride_vecs.vec_a[stride_a_indices[1]] &&
(dimensions == 2 ||
lengths_cplx[stride_a_indices[0]] * lengths_cplx[stride_a_indices[1]] <=
stride_vecs.vec_a[stride_a_indices[2]]));
const bool vec_b_valid_as_reals =
dimensions == 1 ||
(lengths_cplx[stride_b_indices[0]] <= stride_vecs.vec_b[stride_b_indices[1]] &&
(dimensions == 2 ||
lengths_cplx[stride_b_indices[0]] * lengths_cplx[stride_b_indices[1]] <=
stride_vecs.vec_b[stride_b_indices[2]]));
// Test if the stride vector being used as the fwd domain for each direction has valid strides for that use.
bool valid_forward =
stride_vecs.fwd_in == stride_vecs.vec_a && vec_a_valid_as_reals || vec_b_valid_as_reals;
bool valid_backward = stride_vecs.bwd_out == stride_vecs.vec_a && vec_a_valid_as_reals ||
vec_b_valid_as_reals;
// Checks whether a stride vector leaves enough room for data of the given lengths.
// NOTE: despite the name, this returns true when the lengths (or their product)
// are less than or equal to the corresponding strides, i.e. when the data fits.
//   svec           - stride vector to test; indexed by the entries of sindices.
//   sindices       - dimension indices used to look up both lengths and strides
//                    (presumably ordered by increasing stride - confirm against caller).
//   domain_lengths - per-dimension lengths of the domain being tested.
// `dimensions` is captured by copy from the enclosing scope. When it is 1 the
// check always passes; when it is 2 only the first comparison is made; otherwise
// the product of the first two lengths is also compared against svec[sindices[2]].
auto are_strides_smaller_than_lengths = [=](auto& svec, auto& sindices,
auto& domain_lengths) {
return dimensions == 1 ||
(domain_lengths[sindices[0]] <= svec[sindices[1]] &&
(dimensions == 2 || domain_lengths[sindices[0]] * domain_lengths[sindices[1]] <=
svec[sindices[2]]));
};

const bool vec_a_valid_as_fwd_domain =
are_strides_smaller_than_lengths(stride_vecs.vec_a, stride_a_indices, lengths);
const bool vec_b_valid_as_fwd_domain =
are_strides_smaller_than_lengths(stride_vecs.vec_b, stride_b_indices, lengths);
const bool vec_a_valid_as_bwd_domain =
are_strides_smaller_than_lengths(stride_vecs.vec_a, stride_a_indices, lengths_cplx);
const bool vec_b_valid_as_bwd_domain =
are_strides_smaller_than_lengths(stride_vecs.vec_b, stride_b_indices, lengths_cplx);

// Test if the stride vector being used as the fwd/bwd domain for each direction has valid strides for that use.
bool valid_forward = (stride_vecs.fwd_in == stride_vecs.vec_a &&
vec_a_valid_as_fwd_domain && vec_b_valid_as_bwd_domain) ||
(vec_b_valid_as_fwd_domain && vec_a_valid_as_bwd_domain);
bool valid_backward = (stride_vecs.bwd_in == stride_vecs.vec_a &&
vec_a_valid_as_bwd_domain && vec_b_valid_as_fwd_domain) ||
(vec_b_valid_as_bwd_domain && vec_a_valid_as_fwd_domain);

if (!valid_forward && !valid_backward) {
throw mkl::exception("dft/backends/cufft", __FUNCTION__, "Invalid strides.");
Expand Down
42 changes: 27 additions & 15 deletions src/dft/descriptor.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,28 @@ namespace mkl {
namespace dft {
namespace detail {

// Compute the default strides. Modifies fwd_strides and bwd_strides arguments.
// Compute the default strides. Modifies fwd_strides and bwd_strides arguments
template <domain dom>
inline void compute_default_strides(const std::vector<std::int64_t>& dimensions,
std::vector<std::int64_t>& fwd_strides,
std::vector<std::int64_t>& bwd_strides) {
auto rank = dimensions.size();
std::vector<std::int64_t> strides(rank + 1, 1);
for (auto i = rank - 1; i > 0; --i) {
strides[i] = strides[i + 1] * dimensions[i];
const auto rank = dimensions.size();
fwd_strides = std::vector<std::int64_t>(rank + 1, 1);
fwd_strides[0] = 0;
bwd_strides = fwd_strides;
if (rank == 1) {
return;
}

bwd_strides[rank - 1] =
dom == domain::COMPLEX ? dimensions[rank - 1] : (dimensions[rank - 1] / 2) + 1;
fwd_strides[rank - 1] =
dom == domain::COMPLEX ? dimensions[rank - 1] : 2 * bwd_strides[rank - 1];
for (auto i = rank - 1; i > 1; --i) {
// Can't start at rank - 2 with unsigned type and minimum value of rank being 1.
bwd_strides[i - 1] = bwd_strides[i] * dimensions[i - 1];
fwd_strides[i - 1] = fwd_strides[i] * dimensions[i - 1];
}
strides[0] = 0;
// Fwd/Bwd strides and Input/Output strides being the same by default means
// that we don't have to specify if we default to using fwd/bwd strides or
// input/output strides.
bwd_strides = strides;
fwd_strides = std::move(strides);
}

template <precision prec, domain dom>
Expand All @@ -69,12 +76,15 @@ void descriptor<prec, dom>::set_value(config_param param, ...) {
case config_param::PRECISION:
throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter.");
break;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
case config_param::INPUT_STRIDES:
detail::set_value<config_param::INPUT_STRIDES>(values_, va_arg(vl, std::int64_t*));
break;
case config_param::OUTPUT_STRIDES:
detail::set_value<config_param::OUTPUT_STRIDES>(values_, va_arg(vl, std::int64_t*));
break;
#pragma clang diagnostic pop
case config_param::FWD_STRIDES:
detail::set_value<config_param::FWD_STRIDES>(values_, va_arg(vl, std::int64_t*));
break;
Expand Down Expand Up @@ -150,10 +160,9 @@ descriptor<prec, dom>::descriptor(std::vector<std::int64_t> dimensions) {
"Invalid dimension value (negative or 0).");
}
}
compute_default_strides(dimensions, values_.fwd_strides, values_.bwd_strides);
// Assume forward transform.
values_.input_strides = values_.fwd_strides;
values_.output_strides = values_.bwd_strides;
compute_default_strides<dom>(dimensions, values_.fwd_strides, values_.bwd_strides);
values_.input_strides = std::vector<std::int64_t>(dimensions.size() + 1, 0);
values_.output_strides = std::vector<std::int64_t>(dimensions.size() + 1, 0);
values_.bwd_scale = real_t(1.0);
values_.fwd_scale = real_t(1.0);
values_.number_of_transforms = 1;
Expand Down Expand Up @@ -220,6 +229,8 @@ void descriptor<prec, dom>::get_value(config_param param, ...) const {
*va_arg(vl, config_value*) = values_.conj_even_storage;
break;
case config_param::PLACEMENT: *va_arg(vl, config_value*) = values_.placement; break;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
case config_param::INPUT_STRIDES:
std::copy(values_.input_strides.begin(), values_.input_strides.end(),
va_arg(vl, std::int64_t*));
Expand All @@ -228,6 +239,7 @@ void descriptor<prec, dom>::get_value(config_param param, ...) const {
std::copy(values_.output_strides.begin(), values_.output_strides.end(),
va_arg(vl, std::int64_t*));
break;
#pragma clang diagnostic pop
case config_param::FWD_STRIDES:
std::copy(values_.fwd_strides.begin(), values_.fwd_strides.end(),
va_arg(vl, std::int64_t*));
Expand Down
97 changes: 58 additions & 39 deletions tests/unit_tests/dft/include/compute_out_of_place.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,26 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
return test_skipped;
}

descriptor_t descriptor{ sizes };
auto strides_fwd_cpy = strides_fwd;
auto strides_bwd_cpy = strides_bwd;
if (strides_fwd_cpy.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data());
}
else {
strides_fwd_cpy.resize(sizes.size() + 1);
descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data());
}
if (strides_bwd_cpy.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data());
}
else {
strides_bwd_cpy.resize(sizes.size() + 1);
descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data());
}
auto [forward_distance, backward_distance] =
get_default_distances<domain>(sizes, strides_fwd, strides_bwd);
get_default_distances<domain>(sizes, strides_fwd_cpy, strides_bwd_cpy);
auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>());

descriptor_t descriptor{ sizes };
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
Expand All @@ -45,22 +60,12 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
if (strides_fwd.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd.data());
}
if (strides_bwd.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd.data());
}
else if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, complex_strides.data());
}
commit_descriptor(descriptor, sycl_queue);
std::vector<FwdInputType> fwd_data(
strided_copy(input, sizes, strides_fwd, batches, forward_distance));
strided_copy(input, sizes, strides_fwd_cpy, batches, forward_distance));

auto tmp = std::vector<FwdOutputType>(
cast_unsigned(backward_distance * batches + get_default(strides_bwd, 0, 0L)), 0);
cast_unsigned(backward_distance * batches + get_default(strides_bwd_cpy, 0, 0L)), 0);
{
sycl::buffer<FwdInputType, 1> fwd_buf{ fwd_data };
sycl::buffer<FwdOutputType, 1> bwd_buf{ tmp };
Expand All @@ -74,7 +79,7 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
for (std::int64_t i = 0; i < batches; i++) {
EXPECT_TRUE(check_equal_strided<domain == oneapi::mkl::dft::domain::REAL>(
bwd_ptr + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes,
strides_bwd, abs_error_margin, rel_error_margin, std::cout));
strides_bwd_cpy, abs_error_margin, rel_error_margin, std::cout));
}
}

Expand All @@ -88,9 +93,9 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
[this](auto &x) { x *= static_cast<PrecisionType>(forward_elements); });

for (std::int64_t i = 0; i < batches; i++) {
EXPECT_TRUE(check_equal_strided<false>(fwd_data.data() + forward_distance * i,
input.data() + ref_distance * i, sizes, strides_fwd,
abs_error_margin, rel_error_margin, std::cout));
EXPECT_TRUE(check_equal_strided<false>(
fwd_data.data() + forward_distance * i, input.data() + ref_distance * i, sizes,
strides_fwd_cpy, abs_error_margin, rel_error_margin, std::cout));
}

return !::testing::Test::HasFailure();
Expand All @@ -103,10 +108,34 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
}
const std::vector<sycl::event> no_dependencies;

auto [forward_distance, backward_distance] =
get_default_distances<domain>(sizes, strides_fwd, strides_bwd);

descriptor_t descriptor{ sizes };
auto strides_fwd_cpy = strides_fwd;
auto strides_bwd_cpy = strides_bwd;
if (strides_fwd_cpy.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data());
}
else {
strides_fwd_cpy.resize(sizes.size() + 1);
descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data());
}
if (strides_bwd_cpy.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data());
}
else {
strides_bwd_cpy.resize(sizes.size() + 1);
descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data());
}
auto [forward_distance, backward_distance] =
get_default_distances<domain>(sizes, strides_fwd_cpy, strides_bwd_cpy);
auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>());
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
descriptor.set_value(oneapi::mkl::dft::config_param::CONJUGATE_EVEN_STORAGE,
oneapi::mkl::dft::config_value::COMPLEX_COMPLEX);
descriptor.set_value(oneapi::mkl::dft::config_param::PACKED_FORMAT,
oneapi::mkl::dft::config_value::CCE_FORMAT);
}
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
Expand All @@ -118,36 +147,26 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
if (strides_fwd.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd.data());
}
if (strides_bwd.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd.data());
}
else if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, complex_strides.data());
}
commit_descriptor(descriptor, sycl_queue);

auto ua_input = usm_allocator_t<FwdInputType>(cxt, *dev);
auto ua_output = usm_allocator_t<FwdOutputType>(cxt, *dev);

std::vector<FwdInputType, decltype(ua_input)> fwd(
strided_copy(input, sizes, strides_fwd, batches, forward_distance, ua_input), ua_input);
strided_copy(input, sizes, strides_fwd_cpy, batches, forward_distance, ua_input), ua_input);
std::vector<FwdOutputType, decltype(ua_output)> bwd(
cast_unsigned(backward_distance * batches + get_default(strides_bwd, 0, 0L)), ua_output);
cast_unsigned(backward_distance * batches + get_default(strides_bwd_cpy, 0, 0L)),
ua_output);

oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType, FwdOutputType>(
descriptor, fwd.data(), bwd.data(), no_dependencies)
.wait_and_throw();

auto bwd_ptr = &bwd[0];
auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>());
for (std::int64_t i = 0; i < batches; i++) {
EXPECT_TRUE(check_equal_strided<domain == oneapi::mkl::dft::domain::REAL>(
bwd_ptr + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes,
strides_bwd, abs_error_margin, rel_error_margin, std::cout));
strides_bwd_cpy, abs_error_margin, rel_error_margin, std::cout));
}

oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>, FwdOutputType,
Expand All @@ -160,9 +179,9 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
[this](auto &x) { x *= static_cast<PrecisionType>(forward_elements); });

for (std::int64_t i = 0; i < batches; i++) {
EXPECT_TRUE(check_equal_strided<false>(fwd.data() + forward_distance * i,
input.data() + ref_distance * i, sizes, strides_fwd,
abs_error_margin, rel_error_margin, std::cout));
EXPECT_TRUE(check_equal_strided<false>(
fwd.data() + forward_distance * i, input.data() + ref_distance * i, sizes,
strides_fwd_cpy, abs_error_margin, rel_error_margin, std::cout));
}

return !::testing::Test::HasFailure();
Expand Down
Loading

0 comments on commit 128ef68

Please sign in to comment.