Skip to content

Commit

Permalink
Remove get_mem for USM
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 12, 2024
1 parent 70cdbe0 commit 8173331
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 31 deletions.
47 changes: 21 additions & 26 deletions src/sparse_blas/backends/cusparse/cusparse_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ void init_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, st
sc.get_handle(queue);
auto cuda_value_type = CudaEnumType<fpType>::value;
cusparseDnVecDescr_t cu_dvhandle;
CUSPARSE_ERR_FUNC(cusparseCreateDnVec, &cu_dvhandle, size, sc.get_mem(val),
cuda_value_type);
CUSPARSE_ERR_FUNC(cusparseCreateDnVec, &cu_dvhandle, size, val, cuda_value_type);
*p_dvhandle = new dense_vector_handle(cu_dvhandle, val, size);
});
});
Expand Down Expand Up @@ -104,13 +103,12 @@ void set_dense_vector_data(sycl::queue &queue, dense_vector_handle_t dvhandle, s
if (dvhandle->size != size) {
CUSPARSE_ERR_FUNC(cusparseDestroyDnVec, dvhandle->backend_handle);
auto cuda_value_type = CudaEnumType<fpType>::value;
CUSPARSE_ERR_FUNC(cusparseCreateDnVec, &dvhandle->backend_handle, size,
sc.get_mem(val), cuda_value_type);
CUSPARSE_ERR_FUNC(cusparseCreateDnVec, &dvhandle->backend_handle, size, val,
cuda_value_type);
dvhandle->size = size;
}
else {
CUSPARSE_ERR_FUNC(cusparseDnVecSetValues, dvhandle->backend_handle,
sc.get_mem(val));
CUSPARSE_ERR_FUNC(cusparseDnVecSetValues, dvhandle->backend_handle, val);
}
dvhandle->set_usm_ptr(val);
});
Expand Down Expand Up @@ -162,8 +160,8 @@ void init_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, st
auto cuda_value_type = CudaEnumType<fpType>::value;
auto cuda_order = get_cuda_order(dense_layout);
cusparseDnMatDescr_t cu_dmhandle;
CUSPARSE_ERR_FUNC(cusparseCreateDnMat, &cu_dmhandle, num_rows, num_cols, ld,
sc.get_mem(val), cuda_value_type, cuda_order);
CUSPARSE_ERR_FUNC(cusparseCreateDnMat, &cu_dmhandle, num_rows, num_cols, ld, val,
cuda_value_type, cuda_order);
*p_dmhandle =
new dense_matrix_handle(cu_dmhandle, val, num_rows, num_cols, ld, dense_layout);
});
Expand Down Expand Up @@ -218,15 +216,14 @@ void set_dense_matrix_data(sycl::queue &queue, dense_matrix_handle_t dmhandle,
auto cuda_value_type = CudaEnumType<fpType>::value;
auto cuda_order = get_cuda_order(dense_layout);
CUSPARSE_ERR_FUNC(cusparseCreateDnMat, &dmhandle->backend_handle, num_rows,
num_cols, ld, sc.get_mem(val), cuda_value_type, cuda_order);
num_cols, ld, val, cuda_value_type, cuda_order);
dmhandle->num_rows = num_rows;
dmhandle->num_cols = num_cols;
dmhandle->ld = ld;
dmhandle->dense_layout = dense_layout;
}
else {
CUSPARSE_ERR_FUNC(cusparseDnMatSetValues, dmhandle->backend_handle,
sc.get_mem(val));
CUSPARSE_ERR_FUNC(cusparseDnMatSetValues, dmhandle->backend_handle, val);
}
dmhandle->set_usm_ptr(val);
});
Expand Down Expand Up @@ -285,9 +282,8 @@ void init_coo_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64
auto cuda_index_base = get_cuda_index_base(index);
auto cuda_value_type = CudaEnumType<fpType>::value;
cusparseSpMatDescr_t cu_smhandle;
CUSPARSE_ERR_FUNC(cusparseCreateCoo, &cu_smhandle, num_rows, num_cols, nnz,
sc.get_mem(row_ind), sc.get_mem(col_ind), sc.get_mem(val),
cuda_index_type, cuda_index_base, cuda_value_type);
CUSPARSE_ERR_FUNC(cusparseCreateCoo, &cu_smhandle, num_rows, num_cols, nnz, row_ind,
col_ind, val, cuda_index_type, cuda_index_base, cuda_value_type);
*p_smhandle = new matrix_handle(cu_smhandle, row_ind, col_ind, val, num_rows, num_cols,
nnz, index);
});
Expand Down Expand Up @@ -351,16 +347,16 @@ void set_coo_matrix_data(sycl::queue &queue, matrix_handle_t smhandle, std::int6
auto cuda_index_base = get_cuda_index_base(index);
auto cuda_value_type = CudaEnumType<fpType>::value;
CUSPARSE_ERR_FUNC(cusparseCreateCoo, &smhandle->backend_handle, num_rows, num_cols,
nnz, sc.get_mem(row_ind), sc.get_mem(col_ind), sc.get_mem(val),
cuda_index_type, cuda_index_base, cuda_value_type);
nnz, row_ind, col_ind, val, cuda_index_type, cuda_index_base,
cuda_value_type);
smhandle->num_rows = num_rows;
smhandle->num_cols = num_cols;
smhandle->nnz = nnz;
smhandle->index = index;
}
else {
CUSPARSE_ERR_FUNC(cusparseCooSetPointers, smhandle->backend_handle,
sc.get_mem(row_ind), sc.get_mem(col_ind), sc.get_mem(val));
CUSPARSE_ERR_FUNC(cusparseCooSetPointers, smhandle->backend_handle, row_ind,
col_ind, val);
}
smhandle->row_container.set_usm_ptr(row_ind);
smhandle->col_container.set_usm_ptr(col_ind);
Expand Down Expand Up @@ -411,9 +407,9 @@ void init_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64
auto cuda_index_base = get_cuda_index_base(index);
auto cuda_value_type = CudaEnumType<fpType>::value;
cusparseSpMatDescr_t cu_smhandle;
CUSPARSE_ERR_FUNC(cusparseCreateCsr, &cu_smhandle, num_rows, num_cols, nnz,
sc.get_mem(row_ptr), sc.get_mem(col_ind), sc.get_mem(val),
cuda_index_type, cuda_index_type, cuda_index_base, cuda_value_type);
CUSPARSE_ERR_FUNC(cusparseCreateCsr, &cu_smhandle, num_rows, num_cols, nnz, row_ptr,
col_ind, val, cuda_index_type, cuda_index_type, cuda_index_base,
cuda_value_type);
*p_smhandle = new matrix_handle(cu_smhandle, row_ptr, col_ind, val, num_rows, num_cols,
nnz, index);
});
Expand Down Expand Up @@ -477,17 +473,16 @@ void set_csr_matrix_data(sycl::queue &queue, matrix_handle_t smhandle, std::int6
auto cuda_index_base = get_cuda_index_base(index);
auto cuda_value_type = CudaEnumType<fpType>::value;
CUSPARSE_ERR_FUNC(cusparseCreateCsr, &smhandle->backend_handle, num_rows, num_cols,
nnz, sc.get_mem(row_ptr), sc.get_mem(col_ind), sc.get_mem(val),
cuda_index_type, cuda_index_type, cuda_index_base,
cuda_value_type);
nnz, row_ptr, col_ind, val, cuda_index_type, cuda_index_type,
cuda_index_base, cuda_value_type);
smhandle->num_rows = num_rows;
smhandle->num_cols = num_cols;
smhandle->nnz = nnz;
smhandle->index = index;
}
else {
CUSPARSE_ERR_FUNC(cusparseCsrSetPointers, smhandle->backend_handle,
sc.get_mem(row_ptr), sc.get_mem(col_ind), sc.get_mem(val));
CUSPARSE_ERR_FUNC(cusparseCsrSetPointers, smhandle->backend_handle, row_ptr,
col_ind, val);
}
smhandle->row_container.set_usm_ptr(row_ptr);
smhandle->col_container.set_usm_ptr(col_ind);
Expand Down
5 changes: 0 additions & 5 deletions src/sparse_blas/backends/cusparse/cusparse_scope_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ class CusparseScopedContextHandler {
auto cudaPtr = ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(acc);
return reinterpret_cast<void *>(cudaPtr);
}

template <typename T>
inline void *get_mem(T *ptr) {
return reinterpret_cast<void *>(ptr);
}
};

} // namespace oneapi::mkl::sparse::cusparse
Expand Down

0 comments on commit 8173331

Please sign in to comment.