Skip to content

Commit

Permalink
Refactor ext-function caching
Browse files Browse the repository at this point in the history
  • Loading branch information
omarahmed1111 committed Feb 16, 2024
1 parent 456e1c6 commit 9bfbf6e
Show file tree
Hide file tree
Showing 14 changed files with 386 additions and 465 deletions.
8 changes: 2 additions & 6 deletions source/adapters/opencl/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
uint32_t *pNumAdapters) {
if (NumEntries > 0 && phAdapters) {
std::lock_guard<std::mutex> Lock{adapter.Mutex};
if (adapter.RefCount++ == 0) {
cl_ext::ExtFuncPtrCache = std::make_unique<cl_ext::ExtFuncPtrCacheT>();
}
adapter.RefCount++;

*phAdapters = &adapter;
}
Expand All @@ -43,9 +41,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
std::lock_guard<std::mutex> Lock{adapter.Mutex};
if (--adapter.RefCount == 0) {
cl_ext::ExtFuncPtrCache.reset();
}
--adapter.RefCount;
return UR_RESULT_SUCCESS;
}

Expand Down
148 changes: 62 additions & 86 deletions source/adapters/opencl/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "event.hpp"
#include "kernel.hpp"
#include "memory.hpp"
#include "platform.hpp"
#include "queue.hpp"

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
Expand All @@ -24,15 +25,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
ur_queue_handle_t Queue = nullptr;
UR_RETURN_ON_FAILURE(urQueueCreate(hContext, hDevice, nullptr, &Queue));

cl_context CLContext = hContext->get();
cl_ext::clCreateCommandBufferKHR_fn clCreateCommandBufferKHR = nullptr;
cl_int Res =
cl_ext::getExtFuncFromContext<decltype(clCreateCommandBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCreateCommandBufferKHRCache,
cl_ext::CreateCommandBufferName, &clCreateCommandBufferKHR);
ur_platform_handle_t Platform = hDevice->Platform;
cl_ext::clCreateCommandBufferKHR_fn clCreateCommandBufferKHR =
Platform->ExtFuncPtr->clCreateCommandBufferKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCreateCommandBufferKHR,
cl_ext::CreateCommandBufferName,
"cl_khr_command_buffer"));

if (!clCreateCommandBufferKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
cl_int Res = 0;
cl_command_queue CLQueue = Queue->get();
auto CLCommandBuffer = clCreateCommandBufferKHR(1, &CLQueue, nullptr, &Res);
CL_RETURN_ON_FAILURE_AND_SET_NULL(Res, phCommandBuffer);
Expand All @@ -55,14 +55,12 @@ UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
UR_RETURN_ON_FAILURE(urQueueRetain(hCommandBuffer->hInternalQueue));

cl_context CLContext = hCommandBuffer->hContext->get();
cl_ext::clRetainCommandBufferKHR_fn clRetainCommandBuffer = nullptr;
cl_int Res = cl_ext::getExtFuncFromContext<decltype(clRetainCommandBuffer)>(
CLContext, cl_ext::ExtFuncPtrCache->clRetainCommandBufferKHRCache,
cl_ext::RetainCommandBufferName, &clRetainCommandBuffer);

if (!clRetainCommandBuffer || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clRetainCommandBufferKHR_fn clRetainCommandBuffer =
Platform->ExtFuncPtr->clRetainCommandBufferKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clRetainCommandBuffer,
cl_ext::RetainCommandBufferName,
"cl_khr_command_buffer"));

CL_RETURN_ON_FAILURE(clRetainCommandBuffer(hCommandBuffer->CLCommandBuffer));
return UR_RESULT_SUCCESS;
Expand All @@ -72,15 +70,12 @@ UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
UR_RETURN_ON_FAILURE(urQueueRelease(hCommandBuffer->hInternalQueue));

cl_context CLContext = hCommandBuffer->hContext->get();
cl_ext::clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHR = nullptr;
cl_int Res =
cl_ext::getExtFuncFromContext<decltype(clReleaseCommandBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clReleaseCommandBufferKHRCache,
cl_ext::ReleaseCommandBufferName, &clReleaseCommandBufferKHR);

if (!clReleaseCommandBufferKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHR =
Platform->ExtFuncPtr->clReleaseCommandBufferKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clReleaseCommandBufferKHR,
cl_ext::ReleaseCommandBufferName,
"cl_khr_command_buffer"));

CL_RETURN_ON_FAILURE(
clReleaseCommandBufferKHR(hCommandBuffer->CLCommandBuffer));
Expand All @@ -89,15 +84,12 @@ urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {

UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
cl_context CLContext = hCommandBuffer->hContext->get();
cl_ext::clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHR = nullptr;
cl_int Res =
cl_ext::getExtFuncFromContext<decltype(clFinalizeCommandBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clFinalizeCommandBufferKHRCache,
cl_ext::FinalizeCommandBufferName, &clFinalizeCommandBufferKHR);

if (!clFinalizeCommandBufferKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHR =
Platform->ExtFuncPtr->clFinalizeCommandBufferKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clFinalizeCommandBufferKHR,
cl_ext::FinalizeCommandBufferName,
"cl_khr_command_buffer"));

CL_RETURN_ON_FAILURE(
clFinalizeCommandBufferKHR(hCommandBuffer->CLCommandBuffer));
Expand All @@ -113,15 +105,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
ur_exp_command_buffer_sync_point_t *pSyncPoint,
ur_exp_command_buffer_command_handle_t *) {

cl_context CLContext = hCommandBuffer->hContext->get();
cl_ext::clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHR = nullptr;
cl_int Res =
cl_ext::getExtFuncFromContext<decltype(clCommandNDRangeKernelKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCommandNDRangeKernelKHRCache,
cl_ext::CommandNRRangeKernelName, &clCommandNDRangeKernelKHR);

if (!clCommandNDRangeKernelKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHR =
Platform->ExtFuncPtr->clCommandNDRangeKernelKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCommandNDRangeKernelKHR,
cl_ext::CommandNRRangeKernelName,
"cl_khr_command_buffer"));

CL_RETURN_ON_FAILURE(clCommandNDRangeKernelKHR(
hCommandBuffer->CLCommandBuffer, nullptr, nullptr, hKernel->get(),
Expand Down Expand Up @@ -160,14 +149,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
ur_exp_command_buffer_sync_point_t *pSyncPoint) {

cl_context CLContext = hCommandBuffer->hContext->get();
cl_ext::clCommandCopyBufferKHR_fn clCommandCopyBufferKHR = nullptr;
cl_int Res = cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferKHRCache,
cl_ext::CommandCopyBufferName, &clCommandCopyBufferKHR);

if (!clCommandCopyBufferKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clCommandCopyBufferKHR_fn clCommandCopyBufferKHR =
Platform->ExtFuncPtr->clCommandCopyBufferKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCommandCopyBufferKHR,
cl_ext::CommandCopyBufferName,
"cl_khr_command_buffer"));

CL_RETURN_ON_FAILURE(clCommandCopyBufferKHR(
hCommandBuffer->CLCommandBuffer, nullptr, hSrcMem->get(), hDstMem->get(),
Expand Down Expand Up @@ -195,15 +182,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
size_t OpenCLDstRect[3]{dstOrigin.x, dstOrigin.y, dstOrigin.z};
size_t OpenCLRegion[3]{region.width, region.height, region.depth};

cl_context CLContext = hCommandBuffer->hContext->get();
cl_ext::clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHR = nullptr;
cl_int Res =
cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferRectKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferRectKHRCache,
cl_ext::CommandCopyBufferRectName, &clCommandCopyBufferRectKHR);

if (!clCommandCopyBufferRectKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHR =
Platform->ExtFuncPtr->clCommandCopyBufferRectKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCommandCopyBufferRectKHR,
cl_ext::CommandCopyBufferRectName,
"cl_khr_command_buffer"));

CL_RETURN_ON_FAILURE(clCommandCopyBufferRectKHR(
hCommandBuffer->CLCommandBuffer, nullptr, hSrcMem->get(), hDstMem->get(),
Expand Down Expand Up @@ -284,14 +268,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
ur_exp_command_buffer_sync_point_t *pSyncPoint) {

cl_context CLContext = hCommandBuffer->hContext->get();
cl_ext::clCommandFillBufferKHR_fn clCommandFillBufferKHR = nullptr;
cl_int Res = cl_ext::getExtFuncFromContext<decltype(clCommandFillBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCommandFillBufferKHRCache,
cl_ext::CommandFillBufferName, &clCommandFillBufferKHR);

if (!clCommandFillBufferKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clCommandFillBufferKHR_fn clCommandFillBufferKHR =
Platform->ExtFuncPtr->clCommandFillBufferKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clCommandFillBufferKHR,
cl_ext::CommandFillBufferName,
"cl_khr_command_buffer"));

CL_RETURN_ON_FAILURE(clCommandFillBufferKHR(
hCommandBuffer->CLCommandBuffer, nullptr, hBuffer->get(), pPattern,
Expand Down Expand Up @@ -340,15 +322,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {

cl_context CLContext = hCommandBuffer->hContext->get();
cl_ext::clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHR = nullptr;
cl_int Res =
cl_ext::getExtFuncFromContext<decltype(clEnqueueCommandBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueCommandBufferKHRCache,
cl_ext::EnqueueCommandBufferName, &clEnqueueCommandBufferKHR);

if (!clEnqueueCommandBufferKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHR =
Platform->ExtFuncPtr->clEnqueueCommandBufferKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clEnqueueCommandBufferKHR,
cl_ext::EnqueueCommandBufferName,
"cl_khr_command_buffer"));

const uint32_t NumberOfQueues = 1;
cl_event Event;
Expand Down Expand Up @@ -396,15 +375,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp(
ur_exp_command_buffer_info_t propName, size_t propSize, void *pPropValue,
size_t *pPropSizeRet) {

cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
cl_ext::clGetCommandBufferInfoKHR_fn clGetCommandBufferInfoKHR = nullptr;
cl_int Res =
cl_ext::getExtFuncFromContext<decltype(clGetCommandBufferInfoKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clGetCommandBufferInfoKHRCache,
cl_ext::GetCommandBufferInfoName, &clGetCommandBufferInfoKHR);

if (!clGetCommandBufferInfoKHR || Res != CL_SUCCESS)
return UR_RESULT_ERROR_INVALID_OPERATION;
ur_platform_handle_t Platform = hCommandBuffer->getPlatform();
cl_ext::clGetCommandBufferInfoKHR_fn clGetCommandBufferInfoKHR =
Platform->ExtFuncPtr->clGetCommandBufferInfoKHRCache;
UR_RETURN_ON_FAILURE(Platform->getExtFunc(&clGetCommandBufferInfoKHR,
cl_ext::GetCommandBufferInfoName,
"cl_khr_command_buffer"));

if (propName != UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
Expand Down
4 changes: 4 additions & 0 deletions source/adapters/opencl/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <CL/cl_ext.h>
#include <ur/ur.hpp>

#include "context.hpp"

struct ur_exp_command_buffer_handle_t_ {
ur_queue_handle_t hInternalQueue;
ur_context_handle_t hContext;
Expand All @@ -21,4 +23,6 @@ struct ur_exp_command_buffer_handle_t_ {
cl_command_buffer_khr CLCommandBuffer)
: hInternalQueue(hQueue), hContext(hContext),
CLCommandBuffer(CLCommandBuffer) {}

ur_platform_handle_t getPlatform() { return hContext->Devices[0]->Platform; }
};
104 changes: 0 additions & 104 deletions source/adapters/opencl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,110 +305,6 @@ using clGetCommandBufferInfoKHR_fn = CL_API_ENTRY cl_int(CL_API_CALL *)(
cl_command_buffer_khr command_buffer, cl_command_buffer_info_khr param_name,
size_t param_value_size, void *param_value, size_t *param_value_size_ret);

template <typename T> struct FuncPtrCache {
std::map<cl_context, T> Map;
std::mutex Mutex;
};

// FIXME: There's currently no mechanism for cleaning up this cache, meaning
// that it is invalidated whenever a context is destroyed. This could lead to
// reusing an invalid function pointer if another context happens to have the
// same native handle.
struct ExtFuncPtrCacheT {
FuncPtrCache<clHostMemAllocINTEL_fn> clHostMemAllocINTELCache;
FuncPtrCache<clDeviceMemAllocINTEL_fn> clDeviceMemAllocINTELCache;
FuncPtrCache<clSharedMemAllocINTEL_fn> clSharedMemAllocINTELCache;
FuncPtrCache<clGetDeviceFunctionPointer_fn> clGetDeviceFunctionPointerCache;
FuncPtrCache<clCreateBufferWithPropertiesINTEL_fn>
clCreateBufferWithPropertiesINTELCache;
FuncPtrCache<clMemBlockingFreeINTEL_fn> clMemBlockingFreeINTELCache;
FuncPtrCache<clSetKernelArgMemPointerINTEL_fn>
clSetKernelArgMemPointerINTELCache;
FuncPtrCache<clEnqueueMemFillINTEL_fn> clEnqueueMemFillINTELCache;
FuncPtrCache<clEnqueueMemcpyINTEL_fn> clEnqueueMemcpyINTELCache;
FuncPtrCache<clGetMemAllocInfoINTEL_fn> clGetMemAllocInfoINTELCache;
FuncPtrCache<clEnqueueWriteGlobalVariable_fn>
clEnqueueWriteGlobalVariableCache;
FuncPtrCache<clEnqueueReadGlobalVariable_fn> clEnqueueReadGlobalVariableCache;
FuncPtrCache<clEnqueueReadHostPipeINTEL_fn> clEnqueueReadHostPipeINTELCache;
FuncPtrCache<clEnqueueWriteHostPipeINTEL_fn> clEnqueueWriteHostPipeINTELCache;
FuncPtrCache<clSetProgramSpecializationConstant_fn>
clSetProgramSpecializationConstantCache;
FuncPtrCache<clCreateCommandBufferKHR_fn> clCreateCommandBufferKHRCache;
FuncPtrCache<clRetainCommandBufferKHR_fn> clRetainCommandBufferKHRCache;
FuncPtrCache<clReleaseCommandBufferKHR_fn> clReleaseCommandBufferKHRCache;
FuncPtrCache<clFinalizeCommandBufferKHR_fn> clFinalizeCommandBufferKHRCache;
FuncPtrCache<clCommandNDRangeKernelKHR_fn> clCommandNDRangeKernelKHRCache;
FuncPtrCache<clCommandCopyBufferKHR_fn> clCommandCopyBufferKHRCache;
FuncPtrCache<clCommandCopyBufferRectKHR_fn> clCommandCopyBufferRectKHRCache;
FuncPtrCache<clCommandFillBufferKHR_fn> clCommandFillBufferKHRCache;
FuncPtrCache<clEnqueueCommandBufferKHR_fn> clEnqueueCommandBufferKHRCache;
FuncPtrCache<clGetCommandBufferInfoKHR_fn> clGetCommandBufferInfoKHRCache;
};
// A raw pointer is used here since the lifetime of this map has to be tied to
// piTeardown to avoid issues with static destruction order (a user application
// might have static objects that indirectly access this cache in their
// destructor).
inline std::unique_ptr<ExtFuncPtrCacheT> ExtFuncPtrCache;

// USM helper function to get an extension function pointer
template <typename T>
static ur_result_t getExtFuncFromContext(cl_context Context,
FuncPtrCache<T> &FPtrCache,
const char *FuncName, T *Fptr) {
// TODO
// Potentially redo caching as UR interface changes.
// if cached, return cached FuncPtr
std::lock_guard<std::mutex> CacheLock{FPtrCache.Mutex};
std::map<cl_context, T> &FPtrMap = FPtrCache.Map;
auto It = FPtrMap.find(Context);
if (It != FPtrMap.end()) {
auto F = It->second;
// if cached that extension is not available return nullptr and
// UR_RESULT_ERROR_INVALID_VALUE
*Fptr = F;
return F ? UR_RESULT_SUCCESS : UR_RESULT_ERROR_INVALID_VALUE;
}

cl_uint DeviceCount;
cl_int RetErr = clGetContextInfo(Context, CL_CONTEXT_NUM_DEVICES,
sizeof(cl_uint), &DeviceCount, nullptr);

if (RetErr != CL_SUCCESS || DeviceCount < 1) {
return UR_RESULT_ERROR_INVALID_CONTEXT;
}

std::vector<cl_device_id> DevicesInCtx(DeviceCount);
RetErr = clGetContextInfo(Context, CL_CONTEXT_DEVICES,
DeviceCount * sizeof(cl_device_id),
DevicesInCtx.data(), nullptr);

if (RetErr != CL_SUCCESS) {
return UR_RESULT_ERROR_INVALID_CONTEXT;
}

cl_platform_id CurPlatform;
RetErr = clGetDeviceInfo(DevicesInCtx[0], CL_DEVICE_PLATFORM,
sizeof(cl_platform_id), &CurPlatform, nullptr);

if (RetErr != CL_SUCCESS) {
return UR_RESULT_ERROR_INVALID_CONTEXT;
}

T FuncPtr = reinterpret_cast<T>(
clGetExtensionFunctionAddressForPlatform(CurPlatform, FuncName));

if (!FuncPtr) {
// Cache that the extension is not available
FPtrMap[Context] = nullptr;
return UR_RESULT_ERROR_INVALID_VALUE;
}

*Fptr = FuncPtr;
FPtrMap[Context] = FuncPtr;

return UR_RESULT_SUCCESS;
}
} // namespace cl_ext

ur_result_t mapCLErrorToUR(cl_int Result);
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/opencl/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,7 @@ struct ur_context_handle_t_ {

native_type get() { return Context; }

ur_platform_handle_t getPlatform() { return Devices[0]->Platform; }

const std::vector<ur_device_handle_t> &getDevices() { return Devices; }
};
1 change: 0 additions & 1 deletion source/adapters/opencl/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#pragma once

#include "common.hpp"
#include "platform.hpp"

struct ur_device_handle_t_ {
using native_type = cl_device_id;
Expand Down
Loading

0 comments on commit 9bfbf6e

Please sign in to comment.