Skip to content

Commit

Permalink
Merge branch 'main' into user-after-free
Browse files Browse the repository at this point in the history
  • Loading branch information
AllanZyne committed Mar 14, 2024
2 parents 3d9aa64 + d99d5f7 commit 3c55cac
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 26 deletions.
14 changes: 10 additions & 4 deletions source/adapters/cuda/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,17 +1006,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp(
ArrayDesc.Format = format;

CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC mipmapDesc = {};
mipmapDesc.numLevels = 1;
mipmapDesc.numLevels = pImageDesc->numMipLevel;
mipmapDesc.arrayDesc = ArrayDesc;

// External memory is mapped to a CUmipmappedArray
// If desired, a CUarray is retrieved from the mipmaps 0th level
CUmipmappedArray memMipMap;
UR_CHECK_ERROR(cuExternalMemoryGetMappedMipmappedArray(
&memMipMap, (CUexternalMemory)hInteropMem, &mipmapDesc));

CUarray memArray;
UR_CHECK_ERROR(cuMipmappedArrayGetLevel(&memArray, memMipMap, 0));
if (pImageDesc->numMipLevel > 1) {
*phImageMem = (ur_exp_image_mem_handle_t)memMipMap;
} else {
CUarray memArray;
UR_CHECK_ERROR(cuMipmappedArrayGetLevel(&memArray, memMipMap, 0));

*phImageMem = (ur_exp_image_mem_handle_t)memArray;
*phImageMem = (ur_exp_image_mem_handle_t)memArray;
}

} catch (ur_result_t Err) {
return Err;
Expand Down
22 changes: 11 additions & 11 deletions source/adapters/cuda/tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,20 @@
using tracing_event_t = xpti_td *;
using subscriber_handle_t = CUpti_SubscriberHandle;

using cuptiSubscribe_fn = CUPTIAPI
CUptiResult (*)(CUpti_SubscriberHandle *subscriber, CUpti_CallbackFunc callback,
void *userdata);
using cuptiSubscribe_fn =
CUptiResult(CUPTIAPI *)(CUpti_SubscriberHandle *subscriber,
CUpti_CallbackFunc callback, void *userdata);

using cuptiUnsubscribe_fn = CUPTIAPI
CUptiResult (*)(CUpti_SubscriberHandle subscriber);
using cuptiUnsubscribe_fn =
CUptiResult(CUPTIAPI *)(CUpti_SubscriberHandle subscriber);

using cuptiEnableDomain_fn = CUPTIAPI
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain);
using cuptiEnableDomain_fn =
CUptiResult(CUPTIAPI *)(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain);

using cuptiEnableCallback_fn = CUPTIAPI
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
using cuptiEnableCallback_fn =
CUptiResult(CUPTIAPI *)(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);

#define LOAD_CUPTI_SYM(p, lib, x) \
p.x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
Expand Down
44 changes: 34 additions & 10 deletions source/adapters/opencl/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,36 @@ struct ur_adapter_handle_t_ {
std::mutex Mutex;
};

ur_adapter_handle_t_ adapter{};
static ur_adapter_handle_t_ *adapter = nullptr;

static void globalAdapterShutdown() {
if (cl_ext::ExtFuncPtrCache) {
delete cl_ext::ExtFuncPtrCache;
cl_ext::ExtFuncPtrCache = nullptr;
}
if (adapter) {
delete adapter;
adapter = nullptr;
}
}

UR_APIEXPORT ur_result_t UR_APICALL
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
uint32_t *pNumAdapters) {
if (NumEntries > 0 && phAdapters) {
std::lock_guard<std::mutex> Lock{adapter.Mutex};
if (adapter.RefCount++ == 0) {
cl_ext::ExtFuncPtrCache = std::make_unique<cl_ext::ExtFuncPtrCacheT>();
// Sometimes urAdaterGet may be called after the library already been torn
// down, we also need to create a temporary handle for it.
if (!adapter) {
adapter = new ur_adapter_handle_t_();
atexit(globalAdapterShutdown);
}

*phAdapters = &adapter;
std::lock_guard<std::mutex> Lock{adapter->Mutex};
if (adapter->RefCount++ == 0) {
cl_ext::ExtFuncPtrCache = new cl_ext::ExtFuncPtrCacheT();
}

*phAdapters = adapter;
}

if (pNumAdapters) {
Expand All @@ -37,14 +55,20 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
++adapter.RefCount;
++adapter->RefCount;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
std::lock_guard<std::mutex> Lock{adapter.Mutex};
if (--adapter.RefCount == 0) {
cl_ext::ExtFuncPtrCache.reset();
// Check first if the adapter is valid pointer
if (adapter) {
std::lock_guard<std::mutex> Lock{adapter->Mutex};
if (--adapter->RefCount == 0) {
if (cl_ext::ExtFuncPtrCache) {
delete cl_ext::ExtFuncPtrCache;
cl_ext::ExtFuncPtrCache = nullptr;
}
}
}
return UR_RESULT_SUCCESS;
}
Expand All @@ -68,7 +92,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_ADAPTER_BACKEND_OPENCL);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(adapter.RefCount.load());
return ReturnValue(adapter->RefCount.load());
default:
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/opencl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ struct ExtFuncPtrCacheT {
// piTeardown to avoid issues with static destruction order (a user application
// might have static objects that indirectly access this cache in their
// destructor).
inline std::unique_ptr<ExtFuncPtrCacheT> ExtFuncPtrCache;
inline ExtFuncPtrCacheT *ExtFuncPtrCache;

// USM helper function to get an extension function pointer
template <typename T>
Expand Down

0 comments on commit 3c55cac

Please sign in to comment.