Skip to content

Commit

Permalink
Stopgap
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Sep 10, 2024
1 parent 59edda0 commit 5cb7c69
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
43 changes: 38 additions & 5 deletions cpp/include/kvikio/posix_io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,37 @@ namespace kvikio {

namespace detail {

class cuda_context_checker {
public:
static bool is_primary_context_active() noexcept
{
try {
CUcontext current_ctx{};
CUdevice device_handle{};

CUDA_DRIVER_TRY(cudaAPI::instance().CtxGetCurrent(&current_ctx));
if (current_ctx == nullptr) {
int device_idx{0};
CUDA_DRIVER_TRY(cudaAPI::instance().DeviceGet(&device_handle, device_idx));
} else {
CUDA_DRIVER_TRY(cudaAPI::instance().CtxGetDevice(&device_handle));
}

// Whether the current context, if it exists, is a primary or user-created context,
// here we use the primary context to determine if cudaDeviceReset() has been called.
[[maybe_unused]] unsigned int flags{};
int active_state{};
CUDA_DRIVER_TRY(
cudaAPI::instance().DevicePrimaryCtxGetState(device_handle, &flags, &active_state));

return static_cast<bool>(active_state);
} catch (const CUfileException& e) {
std::cerr << e.what() << std::endl;
}
return false;
}
};

/**
* @brief Singleton class to retrieve a CUDA stream for device-host copying
*
Expand All @@ -44,11 +75,13 @@ class StreamsByThread {
StreamsByThread() = default;
~StreamsByThread() noexcept
{
for (auto& [_, stream] : _streams) {
try {
CUDA_DRIVER_TRY(cudaAPI::instance().StreamDestroy(stream));
} catch (const CUfileException& e) {
std::cerr << e.what() << std::endl;
if (cuda_context_checker::is_primary_context_active()) {
for (auto& [_, stream] : _streams) {
try {
CUDA_DRIVER_TRY(cudaAPI::instance().StreamDestroy(stream));
} catch (const CUfileException& e) {
std::cerr << e.what() << std::endl;
}
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/include/kvikio/shim/cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class cudaAPI {
decltype(cuStreamSynchronize)* StreamSynchronize{nullptr};
decltype(cuStreamCreate)* StreamCreate{nullptr};
decltype(cuStreamDestroy)* StreamDestroy{nullptr};
decltype(cuCtxGetDevice)* CtxGetDevice{nullptr};
decltype(cuDevicePrimaryCtxGetState)* DevicePrimaryCtxGetState{nullptr};

private:
#ifdef KVIKIO_CUDA_FOUND
Expand Down Expand Up @@ -76,6 +78,8 @@ class cudaAPI {
get_symbol(StreamSynchronize, lib, KVIKIO_STRINGIFY(cuStreamSynchronize));
get_symbol(StreamCreate, lib, KVIKIO_STRINGIFY(cuStreamCreate));
get_symbol(StreamDestroy, lib, KVIKIO_STRINGIFY(cuStreamDestroy));
get_symbol(CtxGetDevice, lib, KVIKIO_STRINGIFY(cuCtxGetDevice));
get_symbol(DevicePrimaryCtxGetState, lib, KVIKIO_STRINGIFY(cuDevicePrimaryCtxGetState));
}
#else
cudaAPI() { throw std::runtime_error("KvikIO not compiled with CUDA support"); }
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/kvikio/shim/cuda_h_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,7 @@ CUresult cuDevicePrimaryCtxRelease(...);
CUresult cuStreamCreate(...);
CUresult cuStreamDestroy(...);
CUresult cuStreamSynchronize(...);
CUresult cuCtxGetDevice(...);
CUresult cuDevicePrimaryCtxGetState(...);

#endif

0 comments on commit 5cb7c69

Please sign in to comment.