Skip to content

Commit

Permalink
Merge pull request #1070 from pasaulais/pa/dlopen-cupti
Browse files Browse the repository at this point in the history
[CUDA] Dynamically load the CUPTI library when tracing
  • Loading branch information
kbenzie committed Jan 24, 2024
2 parents 6fb1e54 + 4b2ac71 commit 5b3750d
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 26 deletions.
4 changes: 4 additions & 0 deletions source/adapters/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ else()
message(WARNING "CUDA adapter USM pools are disabled, set UMF_ENABLE_POOL_TRACKING to enable them")
endif()

if (CUDA_cupti_LIBRARY)
target_compile_definitions("ur_adapter_cuda" PRIVATE CUPTI_LIB_PATH="${CUDA_cupti_LIBRARY}")
endif()

target_link_libraries(${TARGET_NAME} PRIVATE
${PROJECT_NAME}::headers
${PROJECT_NAME}::common
Expand Down
12 changes: 7 additions & 5 deletions source/adapters/cuda/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
#include <ur_api.h>

#include "common.hpp"

void enableCUDATracing();
void disableCUDATracing();
#include "tracing.hpp"

struct ur_adapter_handle_t_ {
std::atomic<uint32_t> RefCount = 0;
std::mutex Mutex;
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
};

ur_adapter_handle_t_ adapter{};
Expand All @@ -28,7 +27,8 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
if (NumEntries > 0 && phAdapters) {
std::lock_guard<std::mutex> Lock{adapter.Mutex};
if (adapter.RefCount++ == 0) {
enableCUDATracing();
adapter.TracingCtx = createCUDATracingContext();
enableCUDATracing(adapter.TracingCtx);
}

*phAdapters = &adapter;
Expand All @@ -50,7 +50,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
std::lock_guard<std::mutex> Lock{adapter.Mutex};
if (--adapter.RefCount == 0) {
disableCUDATracing();
disableCUDATracing(adapter.TracingCtx);
freeCUDATracingContext(adapter.TracingCtx);
adapter.TracingCtx = nullptr;
}
return UR_RESULT_SUCCESS;
}
Expand Down
169 changes: 150 additions & 19 deletions source/adapters/cuda/tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,77 @@
#include <cupti.h>
#endif // XPTI_ENABLE_INSTRUMENTATION

#include "tracing.hpp"
#include "ur_lib_loader.hpp"
#include <exception>
#include <iostream>

#ifdef XPTI_ENABLE_INSTRUMENTATION
using tracing_event_t = xpti_td *;
using subscriber_handle_t = CUpti_SubscriberHandle;

using cuptiSubscribe_fn = CUPTIAPI
CUptiResult (*)(CUpti_SubscriberHandle *subscriber, CUpti_CallbackFunc callback,
void *userdata);

using cuptiUnsubscribe_fn = CUPTIAPI
CUptiResult (*)(CUpti_SubscriberHandle subscriber);

using cuptiEnableDomain_fn = CUPTIAPI
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain);

using cuptiEnableCallback_fn = CUPTIAPI
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);

#define LOAD_CUPTI_SYM(p, lib, x) \
p.x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
"cupti" #x);

#else
using tracing_event_t = void *;
using subscriber_handle_t = void *;
using cuptiSubscribe_fn = void *;
using cuptiUnsubscribe_fn = void *;
using cuptiEnableDomain_fn = void *;
using cuptiEnableCallback_fn = void *;
#endif // XPTI_ENABLE_INSTRUMENTATION

struct cupti_table_t_ {
cuptiSubscribe_fn Subscribe = nullptr;
cuptiUnsubscribe_fn Unsubscribe = nullptr;
cuptiEnableDomain_fn EnableDomain = nullptr;
cuptiEnableCallback_fn EnableCallback = nullptr;

bool isInitialized() const;
};

struct cuda_tracing_context_t_ {
tracing_event_t CallEvent = nullptr;
tracing_event_t DebugEvent = nullptr;
subscriber_handle_t Subscriber = nullptr;
ur_loader::LibLoader::Lib Library;
cupti_table_t_ Cupti;
};

#ifdef XPTI_ENABLE_INSTRUMENTATION
constexpr auto CUDA_CALL_STREAM_NAME = "sycl.experimental.cuda.call";
constexpr auto CUDA_DEBUG_STREAM_NAME = "sycl.experimental.cuda.debug";

thread_local uint64_t CallCorrelationID = 0;
thread_local uint64_t DebugCorrelationID = 0;

static xpti_td *GCallEvent = nullptr;
static xpti_td *GDebugEvent = nullptr;

constexpr auto GVerStr = "0.1";
constexpr int GMajVer = 0;
constexpr int GMinVer = 1;

static void cuptiCallback(void *, CUpti_CallbackDomain, CUpti_CallbackId CBID,
const void *CBData) {
static void cuptiCallback(void *UserData, CUpti_CallbackDomain,
CUpti_CallbackId CBID, const void *CBData) {
if (xptiTraceEnabled()) {
const auto *CBInfo = static_cast<const CUpti_CallbackData *>(CBData);
cuda_tracing_context_t_ *Ctx =
static_cast<cuda_tracing_context_t_ *>(UserData);

if (CBInfo->callbackSite == CUPTI_API_ENTER) {
CallCorrelationID = xptiGetUniqueId();
Expand All @@ -57,22 +107,95 @@ static void cuptiCallback(void *, CUpti_CallbackDomain, CUpti_CallbackId CBID,
uint8_t CallStreamID = xptiRegisterStream(CUDA_CALL_STREAM_NAME);
uint8_t DebugStreamID = xptiRegisterStream(CUDA_DEBUG_STREAM_NAME);

xptiNotifySubscribers(CallStreamID, TraceType, GCallEvent, nullptr,
xptiNotifySubscribers(CallStreamID, TraceType, Ctx->CallEvent, nullptr,
CallCorrelationID, FuncName);

xpti::function_with_args_t Payload{
FuncID, FuncName, const_cast<void *>(CBInfo->functionParams),
CBInfo->functionReturnValue, CBInfo->context};
xptiNotifySubscribers(DebugStreamID, TraceTypeArgs, GDebugEvent, nullptr,
DebugCorrelationID, &Payload);
xptiNotifySubscribers(DebugStreamID, TraceTypeArgs, Ctx->DebugEvent,
nullptr, DebugCorrelationID, &Payload);
}
}
#endif

cuda_tracing_context_t_ *createCUDATracingContext() {
#ifdef XPTI_ENABLE_INSTRUMENTATION
if (!xptiTraceEnabled())
return nullptr;
return new cuda_tracing_context_t_;
#else
return nullptr;
#endif // XPTI_ENABLE_INSTRUMENTATION
}

void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx) {
#ifdef XPTI_ENABLE_INSTRUMENTATION
unloadCUDATracingLibrary(Ctx);
delete Ctx;
#else
(void)Ctx;
#endif // XPTI_ENABLE_INSTRUMENTATION
}

bool cupti_table_t_::isInitialized() const {
return Subscribe && Unsubscribe && EnableDomain && EnableCallback;
}

bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
#if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH)
if (!Ctx)
return false;
if (Ctx->Library)
return true;
auto Lib{ur_loader::LibLoader::loadAdapterLibrary(CUPTI_LIB_PATH)};
if (!Lib)
return false;
cupti_table_t_ Table;
LOAD_CUPTI_SYM(Table, Lib, Subscribe)
LOAD_CUPTI_SYM(Table, Lib, Unsubscribe)
LOAD_CUPTI_SYM(Table, Lib, EnableDomain)
LOAD_CUPTI_SYM(Table, Lib, EnableCallback)
if (!Table.isInitialized()) {
return false;
}
Ctx->Library = std::move(Lib);
Ctx->Cupti = Table;
return true;
#else
(void)Ctx;
return false;
#endif // XPTI_ENABLE_INSTRUMENTATION && CUPTI_LIB_PATH
}

void unloadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
#ifdef XPTI_ENABLE_INSTRUMENTATION
if (!Ctx)
return;
Ctx->Library.reset();
Ctx->Cupti = cupti_table_t_();
#else
(void)Ctx;
#endif // XPTI_ENABLE_INSTRUMENTATION
}

void enableCUDATracing() {
#ifdef XPTI_ENABLE_INSTRUMENTATION
if (!xptiTraceEnabled())
return;
static cuda_tracing_context_t_ *Ctx = nullptr;
if (!Ctx)
Ctx = createCUDATracingContext();
enableCUDATracing(Ctx);
#endif
}

void enableCUDATracing(cuda_tracing_context_t_ *Ctx) {
#ifdef XPTI_ENABLE_INSTRUMENTATION
if (!Ctx || !xptiTraceEnabled())
return;
else if (!loadCUDATracingLibrary(Ctx))
return;

xptiRegisterStream(CUDA_CALL_STREAM_NAME);
xptiInitialize(CUDA_CALL_STREAM_NAME, GMajVer, GMinVer, GVerStr);
Expand All @@ -81,31 +204,39 @@ void enableCUDATracing() {

uint64_t Dummy;
xpti::payload_t CUDAPayload("CUDA Plugin Layer");
GCallEvent =
Ctx->CallEvent =
xptiMakeEvent("CUDA Plugin Layer", &CUDAPayload,
xpti::trace_algorithm_event, xpti_at::active, &Dummy);

xpti::payload_t CUDADebugPayload("CUDA Plugin Debug Layer");
GDebugEvent =
Ctx->DebugEvent =
xptiMakeEvent("CUDA Plugin Debug Layer", &CUDADebugPayload,
xpti::trace_algorithm_event, xpti_at::active, &Dummy);

CUpti_SubscriberHandle Subscriber;
cuptiSubscribe(&Subscriber, cuptiCallback, nullptr);
cuptiEnableDomain(1, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
cuptiEnableCallback(0, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
cuptiEnableCallback(0, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
Ctx->Cupti.Subscribe(&Ctx->Subscriber, cuptiCallback, Ctx);
Ctx->Cupti.EnableDomain(1, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
Ctx->Cupti.EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
Ctx->Cupti.EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
#else
(void)Ctx;
#endif
}

void disableCUDATracing() {
void disableCUDATracing(cuda_tracing_context_t_ *Ctx) {
#ifdef XPTI_ENABLE_INSTRUMENTATION
if (!xptiTraceEnabled())
if (!Ctx || !xptiTraceEnabled())
return;

if (Ctx->Subscriber && Ctx->Cupti.isInitialized()) {
Ctx->Cupti.Unsubscribe(Ctx->Subscriber);
Ctx->Subscriber = nullptr;
}

xptiFinalize(CUDA_CALL_STREAM_NAME);
xptiFinalize(CUDA_DEBUG_STREAM_NAME);
#else
(void)Ctx;
#endif // XPTI_ENABLE_INSTRUMENTATION
}
24 changes: 24 additions & 0 deletions source/adapters/cuda/tracing.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//===--------- tracing.hpp - CUDA Host API Tracing -------------------------==//
//
// Copyright (C) 2023 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

struct cuda_tracing_context_t_;

cuda_tracing_context_t_ *createCUDATracingContext();
void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx);

bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx);
void unloadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx);

void enableCUDATracing(cuda_tracing_context_t_ *Ctx);
void disableCUDATracing(cuda_tracing_context_t_ *Ctx);

// Deprecated. Will be removed once pi_cuda has been updated to use the variant
// that takes a context pointer.
void enableCUDATracing();
5 changes: 3 additions & 2 deletions source/common/ur_lib_loader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ class LibLoader {
void operator()(HMODULE handle) { freeAdapterLibrary(handle); }
};

static std::unique_ptr<HMODULE, lib_dtor>
loadAdapterLibrary(const char *name);
using Lib = std::unique_ptr<HMODULE, lib_dtor>;

static Lib loadAdapterLibrary(const char *name);

static void freeAdapterLibrary(HMODULE handle);

Expand Down

0 comments on commit 5b3750d

Please sign in to comment.