Skip to content

Commit

Permalink
switch to portable fft_enqueue_task
Browse files Browse the repository at this point in the history
Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
  • Loading branch information
JackAKirk committed Oct 11, 2024
1 parent f15983c commit 1df1cb1
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 65 deletions.
9 changes: 5 additions & 4 deletions src/dft/backends/cufft/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "oneapi/mkl/dft/types.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"

#include <cufft.h>

Expand Down Expand Up @@ -71,7 +72,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto inout_native = reinterpret_cast<fwd<descriptor_type> *>(
Expand Down Expand Up @@ -117,7 +118,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto in_native = reinterpret_cast<void *>(
Expand Down Expand Up @@ -171,7 +172,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
Expand Down Expand Up @@ -217,7 +218,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
Expand Down
21 changes: 0 additions & 21 deletions src/dft/backends/cufft/execute_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,27 +147,6 @@ inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, c
return stream;
}


/**
 * Submit an interop task through a command group handler.
 *
 * @tparam HandlerT The command group handler type
 * @tparam FnT The type of the callable forming the task body
 * @param cgh Handler the task is submitted through
 * @param f Callable invoked with the task's sycl::interop_handle
 *
 * When SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND is defined, the task is
 * submitted with ext_codeplay_enqueue_native_command; otherwise it is
 * submitted with host_task. The native command extension avoids host
 * synchronization after the CUDA call is complete.
 */
template <typename HandlerT, typename FnT>
static inline void cufft_enqueue_task(HandlerT&& cgh, FnT&& f) {
    // The callable is captured by copy; the same closure is used on both paths.
    auto interop_task = [=](sycl::interop_handle ih) {
        f(std::move(ih));
    };
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    cgh.ext_codeplay_enqueue_native_command(std::move(interop_task));
#else
    cgh.host_task(std::move(interop_task));
#endif
}

} // namespace oneapi::mkl::dft::cufft::detail

#endif
9 changes: 5 additions & 4 deletions src/dft/backends/cufft/forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "oneapi/mkl/dft/types.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"

#include <cufft.h>

Expand Down Expand Up @@ -74,7 +75,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto inout_native = reinterpret_cast<fwd<descriptor_type> *>(
Expand Down Expand Up @@ -119,7 +120,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer<fwd<descr
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

auto in_native = reinterpret_cast<void *>(
Expand Down Expand Up @@ -173,7 +174,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
Expand Down Expand Up @@ -219,7 +220,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(dependencies);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::cufft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, plan);

detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
Expand Down
17 changes: 9 additions & 8 deletions src/dft/backends/rocfft/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "oneapi/mkl/dft/descriptor.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"
#include "rocfft_handle.hpp"

#include <rocfft.h>
Expand Down Expand Up @@ -78,7 +79,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

auto inout_native = reinterpret_cast<void *>(
Expand Down Expand Up @@ -112,7 +113,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto inout_im_acc = inout_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{
Expand Down Expand Up @@ -146,7 +147,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in, out)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand Down Expand Up @@ -181,7 +182,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto out_im_acc = out_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand Down Expand Up @@ -235,7 +236,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

void *inout_ptr = inout;
Expand Down Expand Up @@ -268,7 +269,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar<descrip
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{ inout_re + offsets[0], inout_im + offsets[0] };
Expand Down Expand Up @@ -300,7 +301,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_backward(desc, in, out, deps)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand Down Expand Up @@ -330,7 +331,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar<descrip
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name =
"compute_backward(desc, in_re, in_im, out_re, out_im, deps)";
auto stream = detail::setup_stream(func_name, ih, info);
Expand Down
20 changes: 0 additions & 20 deletions src/dft/backends/rocfft/execute_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,6 @@ inline void execute_checked(const std::string &func, hipStream_t stream, const r
#endif
}

/**
 * Submit an interop task through a command group handler.
 *
 * @tparam HandlerT The command group handler type
 * @tparam FnT The type of the callable forming the task body
 * @param cgh Handler the task is submitted through
 * @param f Callable invoked with the task's sycl::interop_handle
 *
 * When SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND is defined, the task is
 * submitted with ext_codeplay_enqueue_native_command; otherwise it is
 * submitted with host_task. The native command extension avoids host
 * synchronization after the native call is complete.
 */
template <typename HandlerT, typename FnT>
static inline void rocfft_enqueue_task(HandlerT&& cgh, FnT&& f) {
    // The callable is captured by copy; the same closure is used on both paths.
    auto interop_task = [=](sycl::interop_handle ih) {
        f(std::move(ih));
    };
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    cgh.ext_codeplay_enqueue_native_command(std::move(interop_task));
#else
    cgh.host_task(std::move(interop_task));
#endif
}

} // namespace oneapi::mkl::dft::rocfft::detail

#endif
17 changes: 9 additions & 8 deletions src/dft/backends/rocfft/forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "oneapi/mkl/dft/descriptor.hpp"

#include "execute_helper.hpp"
#include "../../execute_helper_generic.hpp"
#include "rocfft_handle.hpp"

#include <rocfft.h>
Expand Down Expand Up @@ -81,7 +82,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
auto inout_acc = inout.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

auto inout_native = reinterpret_cast<void *>(
Expand Down Expand Up @@ -115,7 +116,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
auto inout_im_acc = inout_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{
Expand Down Expand Up @@ -148,7 +149,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer<fwd<descr
auto out_acc = out.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_forward(desc, in, out)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand Down Expand Up @@ -183,7 +184,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
auto out_im_acc = out_im.template get_access<sycl::access::mode::read_write>(cgh);
commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_forward(desc, in_re, in_im, out_re, out_im)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand Down Expand Up @@ -237,7 +238,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

void *inout_ptr = inout;
Expand Down Expand Up @@ -269,7 +270,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalar<descript
sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) {
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);
detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
auto stream = detail::setup_stream(func_name, ih, info);

std::array<void *, 2> inout_native{ inout_re + offsets[0], inout_im + offsets[0] };
Expand Down Expand Up @@ -300,7 +301,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name = "compute_forward(desc, in, out, deps)";
auto stream = detail::setup_stream(func_name, ih, info);

Expand Down Expand Up @@ -330,7 +331,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalar<descript
cgh.depends_on(deps);
commit->depend_on_last_usm_workspace_event_if_rqd(cgh);

detail::rocfft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
const std::string func_name =
"compute_forward(desc, in_re, in_im, out_re, out_im, deps)";
auto stream = detail::setup_stream(func_name, ih, info);
Expand Down
53 changes: 53 additions & 0 deletions src/dft/execute_helper_generic.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*******************************************************************************
* Copyright Codeplay Software Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and limitations under the License.
*
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

#ifndef _ONEMKL_DFT_SRC_CUFFT_EXECUTE_GENERIC_HPP_
#define _ONEMKL_DFT_SRC_CUFFT_EXECUTE_GENERIC_HPP_

#if __has_include(<sycl/sycl.hpp>)
#include <sycl/sycl.hpp>
#else
#include <CL/sycl.hpp>
#endif

namespace oneapi::mkl::dft::detail {

/** Wrap interop API to launch interop host task.
 *
 * @tparam HandlerT The command group handler type
 * @tparam FnT The body of the enqueued task
 * @param cgh Command group handler the task is submitted through
 * @param f Callable invoked with the task's sycl::interop_handle. It is
 *          forwarded into the enqueued closure: moved when passed as an
 *          rvalue, copied when passed as an lvalue.
 *
 * Either uses host task interop API, or enqueue native command extension
 * (selected when SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND is defined).
 * This extension avoids host synchronization after
 * the native call is complete.
 */
template <typename HandlerT, typename FnT>
static inline void fft_enqueue_task(HandlerT&& cgh, FnT&& f) {
    // Forward the callable into the closure rather than copying it with [=]:
    // call sites pass temporary lambdas, so this moves their captured state
    // instead of duplicating it.
    auto interop_task = [task_body = std::forward<FnT>(f)](sycl::interop_handle ih) {
        task_body(std::move(ih));
    };
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    cgh.ext_codeplay_enqueue_native_command(std::move(interop_task));
#else
    cgh.host_task(std::move(interop_task));
#endif
}

} // namespace oneapi::mkl::dft::detail

#endif

0 comments on commit 1df1cb1

Please sign in to comment.