diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 3e60dc6b0..158d52aa5 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -41,14 +41,14 @@ thread_local cublas_handle CublasScopedContextHandler::handle_helper CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle &ih) : ih(ih), needToRecover_(false) { - auto cudaDevice = sycl::get_native(queue.get_device()); + placedContext_ = new sycl::context(queue.get_context()); + auto cudaDevice = ih.get_native_device(); CUresult err; CUcontext desired; - CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); + // Getting the primary context also sets it as the active context + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); if (original_ != desired) { - // Sets the desired context as the active one for the thread - CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); // No context is installed and the suggested context is primary // This is the most common case. We can activate the context in the // thread and leave it there until all the PI context referring to the @@ -87,7 +87,7 @@ void ContextCallback(void *userData) { } cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue &queue) { - auto cudaDevice = sycl::get_native(queue.get_device()); + auto cudaDevice = ih.get_native_device(); CUresult cuErr; CUcontext desired; CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice); diff --git a/src/blas/backends/rocblas/rocblas_scope_handle.cpp b/src/blas/backends/rocblas/rocblas_scope_handle.cpp index fa8d3f005..50d2cd944 100644 --- a/src/blas/backends/rocblas/rocblas_scope_handle.cpp +++ b/src/blas/backends/rocblas/rocblas_scope_handle.cpp @@ -57,14 +57,14 @@ RocblasScopedContextHandler::RocblasScopedContextHandler(sycl::queue queue, sycl::interop_handle &ih) : interop_h(ih), needToRecover_(false) { - auto hipDevice = sycl::get_native(queue.get_device()); + placedContext_ = new sycl::context(queue.get_context()); + auto hipDevice = ih.get_native_device(); hipError_t err; hipCtx_t desired; - HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice); HIP_ERROR_FUNC(hipCtxGetCurrent, err, &original_); + // Getting the primary context also sets it as the active context + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice); if (original_ != desired) { - // Sets the desired context as the active one for the thread - HIP_ERROR_FUNC(hipCtxSetCurrent, err, desired); // No context is installed and the suggested context is primary // This is the most common case. We can activate the context in the // thread and leave it there until all the PI context referring to the @@ -103,7 +103,7 @@ void ContextCallback(void *userData) { } rocblas_handle RocblasScopedContextHandler::get_handle(const sycl::queue &queue) { - auto hipDevice = sycl::get_native(queue.get_device()); + auto hipDevice = interop_h.get_native_device(); hipError_t hipErr; hipCtx_t desired; HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, hipErr, &desired, hipDevice); diff --git a/src/dft/backends/cufft/commit.cpp b/src/dft/backends/cufft/commit.cpp index 0eecddbc5..19507d722 100644 --- a/src/dft/backends/cufft/commit.cpp +++ b/src/dft/backends/cufft/commit.cpp @@ -88,10 +88,6 @@ class cufft_commit final : public dft::detail::commit_impl { throw mkl::exception("dft/backends/cufft", __FUNCTION__, "Failed to change cuda context."); } - if (cuCtxSetCurrent(interopContext) != CUDA_SUCCESS) { - throw mkl::exception("dft/backends/cufft", __FUNCTION__, - "Failed to change cuda context."); - } } } diff --git a/src/lapack/backends/cusolver/cusolver_scope_handle.cpp b/src/lapack/backends/cusolver/cusolver_scope_handle.cpp index 06c5666b3..ea2070e4b 100644 --- a/src/lapack/backends/cusolver/cusolver_scope_handle.cpp +++ b/src/lapack/backends/cusolver/cusolver_scope_handle.cpp @@ -42,14 +42,14 @@ CusolverScopedContextHandler::CusolverScopedContextHandler(sycl::queue queue, sycl::interop_handle &ih) : ih(ih), needToRecover_(false) { - auto cudaDevice = sycl::get_native(queue.get_device()); + placedContext_ = new sycl::context(queue.get_context()); + auto cudaDevice = ih.get_native_device(); CUresult err; CUcontext desired; - CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); + // Getting the primary context also sets it as the active context + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); if (original_ != desired) { - // Sets the desired context as the active one for the thread - CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); // No context is installed and the suggested context is primary // This is the most common case. We can activate the context in the // thread and leave it there until all the PI context referring to the @@ -88,7 +88,7 @@ void ContextCallback(void *userData) { } cusolverDnHandle_t CusolverScopedContextHandler::get_handle(const sycl::queue &queue) { - auto cudaDevice = sycl::get_native(queue.get_device()); + auto cudaDevice = ih.get_native_device(); CUresult cuErr; CUcontext desired; CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice); diff --git a/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp b/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp index 809ba2763..a096f25bb 100644 --- a/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp +++ b/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp @@ -44,14 +44,14 @@ RocsolverScopedContextHandler::RocsolverScopedContextHandler(sycl::queue queue, sycl::interop_handle &ih) : ih(ih), needToRecover_(false) { - auto hipDevice = sycl::get_native(queue.get_device()); + placedContext_ = new sycl::context(queue.get_context()); + auto hipDevice = ih.get_native_device(); hipError_t err; hipCtx_t desired; - HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice); HIP_ERROR_FUNC(hipCtxGetCurrent, err, &original_); + // Getting the primary context also sets it as the active context + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice); if (original_ != desired) { - // Sets the desired context as the active one for the thread - HIP_ERROR_FUNC(hipCtxSetCurrent, err, desired); // No context is installed and the suggested context is primary // This is the most common case. We can activate the context in the // thread and leave it there until all the PI context referring to the @@ -90,7 +90,7 @@ void ContextCallback(void *userData) { } rocblas_handle RocsolverScopedContextHandler::get_handle(const sycl::queue &queue) { - auto hipDevice = sycl::get_native(queue.get_device()); + auto hipDevice = ih.get_native_device(); hipError_t hipErr; hipCtx_t desired; HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, hipErr, &desired, hipDevice);