diff --git a/cpp/include/kvikio/posix_io.hpp b/cpp/include/kvikio/posix_io.hpp index 9e88a3e265..3aa9676fb5 100644 --- a/cpp/include/kvikio/posix_io.hpp +++ b/cpp/include/kvikio/posix_io.hpp @@ -30,6 +30,37 @@ namespace kvikio { namespace detail { +class cuda_context_checker { + public: + static bool is_primary_context_active() noexcept + { + try { + CUcontext current_ctx{}; + CUdevice device_handle{}; + + CUDA_DRIVER_TRY(cudaAPI::instance().CtxGetCurrent(¤t_ctx)); + if (current_ctx == nullptr) { + int device_idx{0}; + CUDA_DRIVER_TRY(cudaAPI::instance().DeviceGet(&device_handle, device_idx)); + } else { + CUDA_DRIVER_TRY(cudaAPI::instance().CtxGetDevice(&device_handle)); + } + + // Whether the current context, if it exists, is a primary or user-created context, + // here we use the primary context to determine if cudaDeviceReset() has been called. + [[maybe_unused]] unsigned int flags{}; + int active_state{}; + CUDA_DRIVER_TRY( + cudaAPI::instance().DevicePrimaryCtxGetState(device_handle, &flags, &active_state)); + + return static_cast(active_state); + } catch (const CUfileException& e) { + std::cerr << e.what() << std::endl; + } + return false; + } +}; + /** * @brief Singleton class to retrieve a CUDA stream for device-host copying * @@ -44,11 +75,13 @@ class StreamsByThread { StreamsByThread() = default; ~StreamsByThread() noexcept { - for (auto& [_, stream] : _streams) { - try { - CUDA_DRIVER_TRY(cudaAPI::instance().StreamDestroy(stream)); - } catch (const CUfileException& e) { - std::cerr << e.what() << std::endl; + if (cuda_context_checker::is_primary_context_active()) { + for (auto& [_, stream] : _streams) { + try { + CUDA_DRIVER_TRY(cudaAPI::instance().StreamDestroy(stream)); + } catch (const CUfileException& e) { + std::cerr << e.what() << std::endl; + } } } } diff --git a/cpp/include/kvikio/shim/cuda.hpp b/cpp/include/kvikio/shim/cuda.hpp index 5d42bd0dcb..e1bf607819 100644 --- a/cpp/include/kvikio/shim/cuda.hpp +++ b/cpp/include/kvikio/shim/cuda.hpp @@ -48,6 +48,8 @@ class cudaAPI { decltype(cuStreamSynchronize)* StreamSynchronize{nullptr}; decltype(cuStreamCreate)* StreamCreate{nullptr}; decltype(cuStreamDestroy)* StreamDestroy{nullptr}; + decltype(cuCtxGetDevice)* CtxGetDevice{nullptr}; + decltype(cuDevicePrimaryCtxGetState)* DevicePrimaryCtxGetState{nullptr}; private: #ifdef KVIKIO_CUDA_FOUND @@ -76,6 +78,8 @@ class cudaAPI { get_symbol(StreamSynchronize, lib, KVIKIO_STRINGIFY(cuStreamSynchronize)); get_symbol(StreamCreate, lib, KVIKIO_STRINGIFY(cuStreamCreate)); get_symbol(StreamDestroy, lib, KVIKIO_STRINGIFY(cuStreamDestroy)); + get_symbol(CtxGetDevice, lib, KVIKIO_STRINGIFY(cuCtxGetDevice)); + get_symbol(DevicePrimaryCtxGetState, lib, KVIKIO_STRINGIFY(cuDevicePrimaryCtxGetState)); } #else cudaAPI() { throw std::runtime_error("KvikIO not compiled with CUDA support"); } diff --git a/cpp/include/kvikio/shim/cuda_h_wrapper.hpp b/cpp/include/kvikio/shim/cuda_h_wrapper.hpp index 0740c99f31..60c67f16dd 100644 --- a/cpp/include/kvikio/shim/cuda_h_wrapper.hpp +++ b/cpp/include/kvikio/shim/cuda_h_wrapper.hpp @@ -63,5 +63,7 @@ CUresult cuDevicePrimaryCtxRelease(...); CUresult cuStreamCreate(...); CUresult cuStreamDestroy(...); CUresult cuStreamSynchronize(...); +CUresult cuCtxGetDevice(...); +CUresult cuDevicePrimaryCtxGetState(...); #endif