Skip to content

Commit

Permalink
Avoid placeholder accessor
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 24, 2024
1 parent d23b24d commit 4db2187
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 58 deletions.
75 changes: 37 additions & 38 deletions src/sparse_blas/backends/cusparse/cusparse_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ void submit_host_task_with_acc(sycl::handler &cgh, sycl::queue &queue, Functor f
// specification but should be true for all the implementations. This
// assumption avoids the overhead of resetting the pointer of all data
// handles for each enqueued command.
cgh.require(workspace_placeholder_acc);
cgh.host_task([functor, queue, workspace_placeholder_acc,
capture_only_accessors...](sycl::interop_handle ih) {
auto unused = std::make_tuple(capture_only_accessors...);
Expand Down Expand Up @@ -151,7 +150,6 @@ void submit_native_command_ext_with_acc(sycl::handler &cgh, sycl::queue &queue,
// assumption avoids the overhead of resetting the pointer of all data
// handles for each enqueued command.
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
cgh.require(workspace_placeholder_acc);
cgh.ext_codeplay_enqueue_native_command([functor, queue, dependencies,
workspace_placeholder_acc,
capture_only_accessors...](sycl::interop_handle ih) {
Expand Down Expand Up @@ -196,36 +194,36 @@ template <bool UseWorkspace, bool UseEnqueueNativeCommandExt, typename Functor,
sycl::event dispatch_submit_impl_fp_int(const std::string &function_name, sycl::queue queue,
const std::vector<sycl::event> &dependencies,
Functor functor, matrix_handle_t sm_handle,
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc,
sycl::buffer<std::uint8_t> workspace_buffer,
Ts... other_containers) {
if (sm_handle->all_use_buffer()) {
detail::data_type value_type = sm_handle->get_value_type();
detail::data_type int_type = sm_handle->get_int_type();

#define ONEMKL_CUSPARSE_SUBMIT(FP_TYPE, INT_TYPE) \
return queue.submit([&](sycl::handler &cgh) { \
cgh.depends_on(dependencies); \
auto fp_accs = get_fp_accessors<FP_TYPE>(cgh, sm_handle, other_containers...); \
auto int_accs = get_int_accessors<INT_TYPE>(cgh, sm_handle); \
if constexpr (UseWorkspace) { \
if constexpr (UseEnqueueNativeCommandExt) { \
submit_native_command_ext_with_acc(cgh, queue, functor, dependencies, \
workspace_placeholder_acc, fp_accs, int_accs); \
} \
else { \
submit_host_task_with_acc(cgh, queue, functor, workspace_placeholder_acc, fp_accs, \
int_accs); \
} \
} \
else { \
(void)workspace_placeholder_acc; \
if constexpr (UseEnqueueNativeCommandExt) { \
submit_native_command_ext(cgh, queue, functor, dependencies, fp_accs, int_accs); \
} \
else { \
submit_host_task(cgh, queue, functor, fp_accs, int_accs); \
} \
} \
#define ONEMKL_CUSPARSE_SUBMIT(FP_TYPE, INT_TYPE) \
return queue.submit([&](sycl::handler &cgh) { \
cgh.depends_on(dependencies); \
auto fp_accs = get_fp_accessors<FP_TYPE>(cgh, sm_handle, other_containers...); \
auto int_accs = get_int_accessors<INT_TYPE>(cgh, sm_handle); \
auto workspace_acc = workspace_buffer.get_access<sycl::access::mode::read_write>(cgh); \
if constexpr (UseWorkspace) { \
if constexpr (UseEnqueueNativeCommandExt) { \
submit_native_command_ext_with_acc(cgh, queue, functor, dependencies, \
workspace_acc, fp_accs, int_accs); \
} \
else { \
submit_host_task_with_acc(cgh, queue, functor, workspace_acc, fp_accs, int_accs); \
} \
} \
else { \
(void)workspace_buffer; \
if constexpr (UseEnqueueNativeCommandExt) { \
submit_native_command_ext(cgh, queue, functor, dependencies, fp_accs, int_accs); \
} \
else { \
submit_host_task(cgh, queue, functor, fp_accs, int_accs); \
} \
} \
})
#define ONEMKL_CUSPARSE_SUBMIT_INT(FP_TYPE) \
if (int_type == detail::data_type::int32) { \
Expand Down Expand Up @@ -318,14 +316,12 @@ sycl::event dispatch_submit_impl_fp(const std::string &function_name, sycl::queu
/// Helper function for dispatch_submit_impl_fp_int
template <typename Functor, typename... Ts>
sycl::event dispatch_submit(const std::string &function_name, sycl::queue queue, Functor functor,
matrix_handle_t sm_handle,
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc,
matrix_handle_t sm_handle, sycl::buffer<std::uint8_t> workspace_buffer,
Ts... other_containers) {
constexpr bool UseWorkspace = true;
constexpr bool UseEnqueueNativeCommandExt = false;
return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
function_name, queue, {}, functor, sm_handle, workspace_placeholder_acc,
other_containers...);
function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...);
}

/// Helper function for dispatch_submit_impl_fp_int
Expand All @@ -335,8 +331,9 @@ sycl::event dispatch_submit(const std::string &function_name, sycl::queue queue,
matrix_handle_t sm_handle, Ts... other_containers) {
constexpr bool UseWorkspace = false;
constexpr bool UseEnqueueNativeCommandExt = false;
sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
function_name, queue, dependencies, functor, sm_handle, {}, other_containers...);
function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...);
}

/// Helper function for dispatch_submit_impl_fp_int
Expand All @@ -345,15 +342,16 @@ sycl::event dispatch_submit(const std::string &function_name, sycl::queue queue,
matrix_handle_t sm_handle, Ts... other_containers) {
constexpr bool UseWorkspace = false;
constexpr bool UseEnqueueNativeCommandExt = false;
sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
function_name, queue, {}, functor, sm_handle, {}, other_containers...);
function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...);
}

/// Helper function for dispatch_submit_impl_fp_int
template <typename Functor, typename... Ts>
sycl::event dispatch_submit_native_ext(const std::string &function_name, sycl::queue queue,
Functor functor, matrix_handle_t sm_handle,
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc,
sycl::buffer<std::uint8_t> workspace_buffer,
Ts... other_containers) {
constexpr bool UseWorkspace = true;
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
Expand All @@ -362,8 +360,7 @@ sycl::event dispatch_submit_native_ext(const std::string &function_name, sycl::q
constexpr bool UseEnqueueNativeCommandExt = false;
#endif
return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
function_name, queue, {}, functor, sm_handle, workspace_placeholder_acc,
other_containers...);
function_name, queue, {}, functor, sm_handle, workspace_buffer, other_containers...);
}

/// Helper function for dispatch_submit_impl_fp_int
Expand All @@ -378,8 +375,9 @@ sycl::event dispatch_submit_native_ext(const std::string &function_name, sycl::q
#else
constexpr bool UseEnqueueNativeCommandExt = false;
#endif
sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
function_name, queue, dependencies, functor, sm_handle, {}, other_containers...);
function_name, queue, dependencies, functor, sm_handle, no_workspace, other_containers...);
}

/// Helper function for dispatch_submit_impl_fp_int
Expand All @@ -393,8 +391,9 @@ sycl::event dispatch_submit_native_ext(const std::string &function_name, sycl::q
#else
constexpr bool UseEnqueueNativeCommandExt = false;
#endif
sycl::buffer<std::uint8_t> no_workspace(sycl::range<1>(0));
return dispatch_submit_impl_fp_int<UseWorkspace, UseEnqueueNativeCommandExt>(
function_name, queue, {}, functor, sm_handle, {}, other_containers...);
function_name, queue, {}, functor, sm_handle, no_workspace, other_containers...);
}

} // namespace oneapi::mkl::sparse::cusparse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,7 @@ void spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::
workspace_ptr, is_alpha_host_accessible);
};

sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(workspace);
dispatch_submit(__func__, queue, functor, A_handle, workspace_placeholder_acc, B_handle,
C_handle);
dispatch_submit(__func__, queue, functor, A_handle, workspace, B_handle, C_handle);
}

sycl::event spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA,
Expand Down Expand Up @@ -268,10 +266,9 @@ sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
auto workspace_ptr = sc.get_mem(workspace_acc);
compute_functor(sc, workspace_ptr);
};
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(
spmm_descr->workspace.get_buffer<std::uint8_t>());
return dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle,
workspace_placeholder_acc, B_handle, C_handle);
spmm_descr->workspace.get_buffer<std::uint8_t>(),
B_handle, C_handle);
}
else {
// The same dispatch_submit can be used for USM or buffers if no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a

// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(workspace);
dispatch_submit(__func__, queue, functor, A_handle, workspace_placeholder_acc, x_handle,
y_handle);
dispatch_submit(__func__, queue, functor, A_handle, workspace, x_handle, y_handle);
}
else {
auto functor = [=](CusparseScopedContextHandler &sc) {
Expand Down Expand Up @@ -284,10 +282,9 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
auto workspace_ptr = sc.get_mem(workspace_acc);
compute_functor(sc, workspace_ptr);
};
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(
spmv_descr->workspace.get_buffer<std::uint8_t>());
return dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle,
workspace_placeholder_acc, x_handle, y_handle);
spmv_descr->workspace.get_buffer<std::uint8_t>(),
x_handle, y_handle);
}
else {
// The same dispatch_submit can be used for USM or buffers if no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a

// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(workspace);
dispatch_submit(__func__, queue, functor, A_handle, workspace_placeholder_acc, x_handle,
y_handle);
dispatch_submit(__func__, queue, functor, A_handle, workspace, x_handle, y_handle);
}
else {
auto functor = [=](CusparseScopedContextHandler &sc) {
Expand Down
5 changes: 0 additions & 5 deletions src/sparse_blas/generic_container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,6 @@ struct generic_sparse_handle {
}

void set_matrix_property(matrix_property property) {
if (format == sparse_format::CSR && property == matrix_property::sorted_by_rows) {
throw mkl::invalid_argument(
"sparse_blas", "set_matrix_property",
"Property `matrix_property::sorted_by_rows` is not compatible with CSR format.");
}
properties_mask |= matrix_property_to_mask(property);
}

Expand Down

0 comments on commit 4db2187

Please sign in to comment.