From 4b2ac71e785a30a1bd4df0b88c89d6f7e1406fad Mon Sep 17 00:00:00 2001 From: Pierre-Andre Saulais Date: Mon, 20 Nov 2023 11:50:33 +0000 Subject: [PATCH] [CUDA] Move CUPTI function pointers to a separate struct --- source/adapters/cuda/tracing.cpp | 60 ++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/source/adapters/cuda/tracing.cpp b/source/adapters/cuda/tracing.cpp index 272c956958..e3acf03165 100644 --- a/source/adapters/cuda/tracing.cpp +++ b/source/adapters/cuda/tracing.cpp @@ -43,8 +43,8 @@ CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber, CUpti_CallbackDomain domain, CUpti_CallbackId cbid); #define LOAD_CUPTI_SYM(p, lib, x) \ - p->x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \ - "cupti" #x); + p.x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \ + "cupti" #x); #else using tracing_event_t = void *; @@ -55,15 +55,21 @@ using cuptiEnableDomain_fn = void *; using cuptiEnableCallback_fn = void *; #endif // XPTI_ENABLE_INSTRUMENTATION +struct cupti_table_t_ { + cuptiSubscribe_fn Subscribe = nullptr; + cuptiUnsubscribe_fn Unsubscribe = nullptr; + cuptiEnableDomain_fn EnableDomain = nullptr; + cuptiEnableCallback_fn EnableCallback = nullptr; + + bool isInitialized() const; +}; + struct cuda_tracing_context_t_ { tracing_event_t CallEvent = nullptr; tracing_event_t DebugEvent = nullptr; subscriber_handle_t Subscriber = nullptr; ur_loader::LibLoader::Lib Library; - cuptiSubscribe_fn Subscribe = nullptr; - cuptiUnsubscribe_fn Unsubscribe = nullptr; - cuptiEnableDomain_fn EnableDomain = nullptr; - cuptiEnableCallback_fn EnableCallback = nullptr; + cupti_table_t_ Cupti; }; #ifdef XPTI_ENABLE_INSTRUMENTATION @@ -132,6 +138,10 @@ void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx) { #endif // XPTI_ENABLE_INSTRUMENTATION } +bool cupti_table_t_::isInitialized() const { + return Subscribe && Unsubscribe && EnableDomain && EnableCallback; +} + bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) { #if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH) if (!Ctx) @@ -141,16 +151,16 @@ bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) { auto Lib{ur_loader::LibLoader::loadAdapterLibrary(CUPTI_LIB_PATH)}; if (!Lib) return false; - LOAD_CUPTI_SYM(Ctx, Lib, Subscribe) - LOAD_CUPTI_SYM(Ctx, Lib, Unsubscribe) - LOAD_CUPTI_SYM(Ctx, Lib, EnableDomain) - LOAD_CUPTI_SYM(Ctx, Lib, EnableCallback) - if (!Ctx->Subscribe || !Ctx->Unsubscribe || !Ctx->EnableDomain || - !Ctx->EnableCallback) { - unloadCUDATracingLibrary(Ctx); + cupti_table_t_ Table; + LOAD_CUPTI_SYM(Table, Lib, Subscribe) + LOAD_CUPTI_SYM(Table, Lib, Unsubscribe) + LOAD_CUPTI_SYM(Table, Lib, EnableDomain) + LOAD_CUPTI_SYM(Table, Lib, EnableCallback) + if (!Table.isInitialized()) { return false; } Ctx->Library = std::move(Lib); + Ctx->Cupti = Table; return true; #else (void)Ctx; @@ -160,14 +170,10 @@ bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) { void unloadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) { #ifdef XPTI_ENABLE_INSTRUMENTATION - if (!Ctx || !Ctx->Library) + if (!Ctx) return; - Ctx->Subscribe = nullptr; - Ctx->Unsubscribe = nullptr; - Ctx->EnableDomain = nullptr; - Ctx->EnableCallback = nullptr; - Ctx->Library.reset(); + Ctx->Cupti = cupti_table_t_(); #else (void)Ctx; #endif // XPTI_ENABLE_INSTRUMENTATION @@ -207,12 +213,12 @@ void enableCUDATracing(cuda_tracing_context_t_ *Ctx) { xptiMakeEvent("CUDA Plugin Debug Layer", &CUDADebugPayload, xpti::trace_algorithm_event, xpti_at::active, &Dummy); - Ctx->Subscribe(&Ctx->Subscriber, cuptiCallback, Ctx); - Ctx->EnableDomain(1, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API); - Ctx->EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuGetErrorString); - Ctx->EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuGetErrorName); + Ctx->Cupti.Subscribe(&Ctx->Subscriber, cuptiCallback, Ctx); + Ctx->Cupti.EnableDomain(1, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API); + Ctx->Cupti.EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API, + CUPTI_DRIVER_TRACE_CBID_cuGetErrorString); + Ctx->Cupti.EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API, + CUPTI_DRIVER_TRACE_CBID_cuGetErrorName); #else (void)Ctx; #endif @@ -223,8 +229,8 @@ void disableCUDATracing(cuda_tracing_context_t_ *Ctx) { if (!Ctx || !xptiTraceEnabled()) return; - if (Ctx->Subscriber) { - Ctx->Unsubscribe(Ctx->Subscriber); + if (Ctx->Subscriber && Ctx->Cupti.isInitialized()) { + Ctx->Cupti.Unsubscribe(Ctx->Subscriber); Ctx->Subscriber = nullptr; }