Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rocfft][cufft] DFT update host task to use native command #578

Merged
merged 11 commits into from
Oct 14, 2024
9 changes: 5 additions & 4 deletions src/dft/backends/cufft/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "oneapi/mkl/dft/types.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"

#include <cufft.h>

Expand Down Expand Up @@ -71,7 +72,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto inout_native = reinterpret_cast<fwd<descriptor_type> *>(
Expand Down Expand Up @@ -117,7 +118,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto in_native = reinterpret_cast<void *>(
Expand Down Expand Up @@ -171,7 +172,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
Expand Down Expand Up @@ -217,7 +218,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
Expand Down
6 changes: 5 additions & 1 deletion src/dft/backends/cufft/execute_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,16 @@ void cufft_execute(const std::string &func, CUstream stream, cufftHandle plan, v
}
}
}

#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
// If not using the enqueue native extension, the host task must wait on the
// asynchronous operation to complete. Otherwise it reports the operation
// as complete early.
auto result = cuStreamSynchronize(stream);
if (result != CUDA_SUCCESS) {
throw oneapi::mkl::exception("dft/backends/cufft", func,
"cuStreamSynchronize returned " + std::to_string(result));
}
#endif
}

inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, cufftHandle plan) {
Expand Down
9 changes: 5 additions & 4 deletions src/dft/backends/cufft/forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "oneapi/mkl/dft/types.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"

#include <cufft.h>

Expand Down Expand Up @@ -74,7 +75,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto inout_native = reinterpret_cast<fwd<descriptor_type> *>(
Expand Down Expand Up @@ -119,7 +120,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer<fwd<descr
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto in_native = reinterpret_cast<void *>(
Expand Down Expand Up @@ -173,7 +174,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
Expand Down Expand Up @@ -219,7 +220,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
Expand Down
42 changes: 18 additions & 24 deletions src/dft/backends/rocfft/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "oneapi/mkl/dft/descriptor.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"
#include "rocfft_handle.hpp"

#include <rocfft.h>
Expand Down Expand Up @@ -78,14 +79,13 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

auto inout_native = reinterpret_cast<void *>(
reinterpret_cast<fwd<descriptor_type> *>(detail::native_mem(ih, inout_acc)) +
offsets[0]);
detail::execute_checked(func_name, plan, &inout_native, nullptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, &inout_native, nullptr, info);
});
});
}
Expand Down Expand Up @@ -113,7 +113,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_im_acc = inout_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{
Expand All @@ -124,8 +124,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
detail::native_mem(ih, inout_im_acc)) +
offsets[0])
};
detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info);
});
});
}
Expand All @@ -148,7 +147,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in, out)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand All @@ -158,8 +157,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_native = reinterpret_cast<void *>(
reinterpret_cast<fwd<descriptor_type> *>(detail::native_mem(ih, out_acc)) +
offsets[1]);
detail::execute_checked(func_name, plan, &in_native, &out_native, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, &in_native, &out_native, info);
});
});
}
Expand All @@ -184,7 +182,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_im_acc = out_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand All @@ -204,8 +202,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
detail::native_mem(ih, out_im_acc)) +
offsets[1])
};
detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info);
});
});
}
Expand Down Expand Up @@ -239,12 +236,11 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

void *inout_ptr = inout;
detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, &inout_ptr, nullptr, info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
Expand Down Expand Up @@ -273,12 +269,12 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar<descrip
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{ inout_re + offsets[0], inout_im + offsets[0] };
detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info);

});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
Expand All @@ -305,14 +301,13 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in, out, deps)";
auto stream = detail::setup_stream(func_name, ih, info);

void *in_ptr = in;
void *out_ptr = out;
detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, &in_ptr, &out_ptr, info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
Expand All @@ -336,15 +331,14 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar<descrip
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

cgh.host_task([=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name =
"compute_backward(desc, in_re, in_im, out_re, out_im, deps)";
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> in_native{ in_re + offsets[0], in_im + offsets[0] };
std::array<void *, 2> out_native{ out_re + offsets[1], out_im + offsets[1] };
detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info);
detail::sync_checked(func_name, stream);
detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info);
});
});
commit->set_last_usm_workspace_event_if_rqd(sycl_event);
Expand Down
18 changes: 12 additions & 6 deletions src/dft/backends/rocfft/execute_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,26 @@ inline hipStream_t setup_stream(const std::string &func, sycl::interop_handle &i
}

inline void sync_checked(const std::string &func, hipStream_t stream) {
auto result = hipStreamSynchronize(stream);
if (result != hipSuccess) {
throw oneapi::mkl::exception("dft/backends/rocfft", func,
"hipStreamSynchronize returned " + std::to_string(result));
}
auto result = hipStreamSynchronize(stream);
if (result != hipSuccess) {
throw oneapi::mkl::exception("dft/backends/rocfft", func,
"hipStreamSynchronize returned " + std::to_string(result));
}
}

inline void execute_checked(const std::string &func, const rocfft_plan plan, void *in_buffer[],
inline void execute_checked(const std::string &func, hipStream_t stream, const rocfft_plan plan, void *in_buffer[],
void *out_buffer[], rocfft_execution_info info) {
auto result = rocfft_execute(plan, in_buffer, out_buffer, info);
if (result != rocfft_status_success) {
throw oneapi::mkl::exception("dft/backends/rocfft", func,
"rocfft_execute returned " + std::to_string(result));
}
#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
// If not using the enqueue native extension, the host task must wait on the
// asynchronous operation to complete. Otherwise it reports the operation
// as complete early.
sync_checked(func, stream);
#endif
}

} // namespace oneapi::mkl::dft::rocfft::detail
Expand Down
Loading
Loading