Skip to content

Commit

Permalink
[CUDA] Move CUPTI function pointers to a separate struct
Browse files Browse the repository at this point in the history
  • Loading branch information
pasaulais authored and kbenzie committed Jan 23, 2024
1 parent 3be1c4b commit 4b2ac71
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions source/adapters/cuda/tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);

#define LOAD_CUPTI_SYM(p, lib, x) \
p->x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
"cupti" #x);
p.x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
"cupti" #x);

#else
using tracing_event_t = void *;
Expand All @@ -55,15 +55,21 @@ using cuptiEnableDomain_fn = void *;
using cuptiEnableCallback_fn = void *;
#endif // XPTI_ENABLE_INSTRUMENTATION

struct cupti_table_t_ {
cuptiSubscribe_fn Subscribe = nullptr;
cuptiUnsubscribe_fn Unsubscribe = nullptr;
cuptiEnableDomain_fn EnableDomain = nullptr;
cuptiEnableCallback_fn EnableCallback = nullptr;

bool isInitialized() const;
};

struct cuda_tracing_context_t_ {
tracing_event_t CallEvent = nullptr;
tracing_event_t DebugEvent = nullptr;
subscriber_handle_t Subscriber = nullptr;
ur_loader::LibLoader::Lib Library;
cuptiSubscribe_fn Subscribe = nullptr;
cuptiUnsubscribe_fn Unsubscribe = nullptr;
cuptiEnableDomain_fn EnableDomain = nullptr;
cuptiEnableCallback_fn EnableCallback = nullptr;
cupti_table_t_ Cupti;
};

#ifdef XPTI_ENABLE_INSTRUMENTATION
Expand Down Expand Up @@ -132,6 +138,10 @@ void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx) {
#endif // XPTI_ENABLE_INSTRUMENTATION
}

bool cupti_table_t_::isInitialized() const {
return Subscribe && Unsubscribe && EnableDomain && EnableCallback;
}

bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
#if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH)
if (!Ctx)
Expand All @@ -141,16 +151,16 @@ bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
auto Lib{ur_loader::LibLoader::loadAdapterLibrary(CUPTI_LIB_PATH)};
if (!Lib)
return false;
LOAD_CUPTI_SYM(Ctx, Lib, Subscribe)
LOAD_CUPTI_SYM(Ctx, Lib, Unsubscribe)
LOAD_CUPTI_SYM(Ctx, Lib, EnableDomain)
LOAD_CUPTI_SYM(Ctx, Lib, EnableCallback)
if (!Ctx->Subscribe || !Ctx->Unsubscribe || !Ctx->EnableDomain ||
!Ctx->EnableCallback) {
unloadCUDATracingLibrary(Ctx);
cupti_table_t_ Table;
LOAD_CUPTI_SYM(Table, Lib, Subscribe)
LOAD_CUPTI_SYM(Table, Lib, Unsubscribe)
LOAD_CUPTI_SYM(Table, Lib, EnableDomain)
LOAD_CUPTI_SYM(Table, Lib, EnableCallback)
if (!Table.isInitialized()) {
return false;
}
Ctx->Library = std::move(Lib);
Ctx->Cupti = Table;
return true;
#else
(void)Ctx;
Expand All @@ -160,14 +170,10 @@ bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {

void unloadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
#ifdef XPTI_ENABLE_INSTRUMENTATION
if (!Ctx || !Ctx->Library)
if (!Ctx)
return;
Ctx->Subscribe = nullptr;
Ctx->Unsubscribe = nullptr;
Ctx->EnableDomain = nullptr;
Ctx->EnableCallback = nullptr;

Ctx->Library.reset();
Ctx->Cupti = cupti_table_t_();
#else
(void)Ctx;
#endif // XPTI_ENABLE_INSTRUMENTATION
Expand Down Expand Up @@ -207,12 +213,12 @@ void enableCUDATracing(cuda_tracing_context_t_ *Ctx) {
xptiMakeEvent("CUDA Plugin Debug Layer", &CUDADebugPayload,
xpti::trace_algorithm_event, xpti_at::active, &Dummy);

Ctx->Subscribe(&Ctx->Subscriber, cuptiCallback, Ctx);
Ctx->EnableDomain(1, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
Ctx->EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
Ctx->EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
Ctx->Cupti.Subscribe(&Ctx->Subscriber, cuptiCallback, Ctx);
Ctx->Cupti.EnableDomain(1, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
Ctx->Cupti.EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
Ctx->Cupti.EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
#else
(void)Ctx;
#endif
Expand All @@ -223,8 +229,8 @@ void disableCUDATracing(cuda_tracing_context_t_ *Ctx) {
if (!Ctx || !xptiTraceEnabled())
return;

if (Ctx->Subscriber) {
Ctx->Unsubscribe(Ctx->Subscriber);
if (Ctx->Subscriber && Ctx->Cupti.isInitialized()) {
Ctx->Cupti.Unsubscribe(Ctx->Subscriber);
Ctx->Subscriber = nullptr;
}

Expand Down

0 comments on commit 4b2ac71

Please sign in to comment.