Skip to content

Commit

Permalink
Add reference counting to the UR layer
Browse files Browse the repository at this point in the history
  • Loading branch information
omarahmed1111 committed Jan 31, 2024
1 parent ea55f6d commit 65247f1
Show file tree
Hide file tree
Showing 18 changed files with 335 additions and 179 deletions.
45 changes: 7 additions & 38 deletions source/adapters/opencl/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
return mapCLErrorToUR(Ret);
}

static cl_int mapURContextInfoToCL(ur_context_info_t URPropName) {

cl_int CLPropName;
switch (URPropName) {
case UR_CONTEXT_INFO_NUM_DEVICES:
CLPropName = CL_CONTEXT_NUM_DEVICES;
break;
case UR_CONTEXT_INFO_DEVICES:
CLPropName = CL_CONTEXT_DEVICES;
break;
case UR_CONTEXT_INFO_REFERENCE_COUNT:
CLPropName = CL_CONTEXT_REFERENCE_COUNT;
break;
default:
CLPropName = -1;
}

return CLPropName;
}

UR_APIEXPORT ur_result_t UR_APICALL
urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
size_t propSize, void *pPropValue, size_t *pPropSizeRet) {

UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
const cl_int CLPropName = mapURContextInfoToCL(propName);

switch (static_cast<uint32_t>(propName)) {
/* 2D USM memops are not supported. */
Expand All @@ -89,17 +68,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
return ReturnValue(&hContext->Devices[0], hContext->DeviceCount);
}
case UR_CONTEXT_INFO_REFERENCE_COUNT: {
size_t CheckPropSize = 0;
auto ClResult = clGetContextInfo(hContext->get(), CLPropName, propSize,
pPropValue, &CheckPropSize);
if (pPropValue && CheckPropSize != propSize) {
return UR_RESULT_ERROR_INVALID_SIZE;
}
CL_RETURN_ON_FAILURE(ClResult);
if (pPropSizeRet) {
*pPropSizeRet = CheckPropSize;
}
return UR_RESULT_SUCCESS;
return ReturnValue(hContext->getReferenceCount());
}
default:
return UR_RESULT_ERROR_INVALID_ENUMERATION;
Expand All @@ -108,16 +77,16 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,

UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {

cl_int Ret = clReleaseContext(hContext->get());
return mapCLErrorToUR(Ret);
if (hContext->decrementReferenceCount() == 0) {
delete hContext;
}
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urContextRetain(ur_context_handle_t hContext) {

cl_int Ret = clRetainContext(hContext->get());
return mapCLErrorToUR(Ret);
hContext->incrementReferenceCount();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
Expand Down
18 changes: 17 additions & 1 deletion source/adapters/opencl/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,31 @@ struct ur_context_handle_t_ {
native_type Context;
std::vector<ur_device_handle_t> Devices;
uint32_t DeviceCount;
std::atomic<uint32_t> RefCount = 0;

ur_context_handle_t_(native_type Ctx, uint32_t DevCount,
const ur_device_handle_t *phDevices)
: Context(Ctx), DeviceCount(DevCount) {
for (uint32_t i = 0; i < DeviceCount; i++) {
Devices.emplace_back(phDevices[i]);
urDeviceRetain(phDevices[i]);
}
RefCount = 1;
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }

static ur_result_t makeWithNative(native_type Ctx, uint32_t DevCount,
const ur_device_handle_t *phDevices,
ur_context_handle_t &Context) {
try {
auto URContext =
std::make_unique<ur_context_handle_t_>(Ctx, DevCount, phDevices);
CL_RETURN_ON_FAILURE(clRetainContext(Ctx));
native_type &NativeContext = URContext->Context;
uint32_t &DeviceCount = URContext->DeviceCount;
if (!DeviceCount) {
Expand All @@ -50,6 +60,7 @@ struct ur_context_handle_t_ {
reinterpret_cast<ur_native_handle_t>(CLDevices[i]);
UR_RETURN_ON_FAILURE(urDeviceCreateWithNativeHandle(
NativeDevice, nullptr, nullptr, &(URContext->Devices[i])));
UR_RETURN_ON_FAILURE(urDeviceRetain(URContext->Devices[i]));
}
}
Context = URContext.release();
Expand All @@ -62,7 +73,12 @@ struct ur_context_handle_t_ {
return UR_RESULT_SUCCESS;
}

~ur_context_handle_t_() {}
~ur_context_handle_t_() {
for (uint32_t i = 0; i < DeviceCount; i++) {
urDeviceRelease(Devices[i]);
}
clReleaseContext(Context);
}

native_type get() { return Context; }
};
23 changes: 14 additions & 9 deletions source/adapters/opencl/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
{"cl_intel_program_scope_host_pipe"}, Supported));
return ReturnValue(Supported);
}
case UR_DEVICE_INFO_REFERENCE_COUNT: {
return ReturnValue(hDevice->getReferenceCount());
}
case UR_DEVICE_INFO_QUEUE_PROPERTIES:
case UR_DEVICE_INFO_QUEUE_ON_DEVICE_PROPERTIES:
case UR_DEVICE_INFO_QUEUE_ON_HOST_PROPERTIES:
Expand Down Expand Up @@ -802,7 +805,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
case UR_DEVICE_INFO_MAX_SAMPLERS:
case UR_DEVICE_INFO_GLOBAL_MEM_CACHELINE_SIZE:
case UR_DEVICE_INFO_MAX_CONSTANT_ARGS:
case UR_DEVICE_INFO_REFERENCE_COUNT:
case UR_DEVICE_INFO_PARTITION_MAX_SUB_DEVICES:
case UR_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
case UR_DEVICE_INFO_GLOBAL_MEM_CACHE_SIZE:
Expand Down Expand Up @@ -958,7 +960,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDevicePartition(
CL_RETURN_ON_FAILURE(clCreateSubDevices(hDevice->get(), CLProperties.data(),
CLNumDevicesRet,
CLSubDevices.data(), nullptr));
for (uint32_t i = 0; i < NumDevices; i++) {
for (uint32_t i = 0; i < std::min(CLNumDevicesRet, NumDevices); i++) {
try {
auto URSubDevice = std::make_unique<ur_device_handle_t_>(
CLSubDevices[i], hDevice->Platform, hDevice);
Expand All @@ -974,19 +976,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urDevicePartition(
return UR_RESULT_SUCCESS;
}

// Root devices ref count are unchanged through out the program lifetime.
UR_APIEXPORT ur_result_t UR_APICALL urDeviceRetain(ur_device_handle_t hDevice) {
if (hDevice->ParentDevice) {
hDevice->incrementReferenceCount();
}

cl_int Result = clRetainDevice(hDevice->get());

return mapCLErrorToUR(Result);
return UR_RESULT_SUCCESS;
}

// Root devices ref count are unchanged through out the program lifetime.
UR_APIEXPORT ur_result_t UR_APICALL
urDeviceRelease(ur_device_handle_t hDevice) {
if (hDevice->ParentDevice && hDevice->decrementReferenceCount() == 0) {
delete hDevice;
}

cl_int Result = clReleaseDevice(hDevice->get());

return mapCLErrorToUR(Result);
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
Expand Down Expand Up @@ -1032,7 +1038,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(
cl_device_id DeviceId = hDevice->get();

// TODO: Cache OpenCL version for each device and platform

auto RetErr = hDevice->getDeviceVersion(DevVer);
CL_RETURN_ON_FAILURE(RetErr);

Expand Down
10 changes: 9 additions & 1 deletion source/adapters/opencl/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ struct ur_device_handle_t_ {
ur_platform_handle_t Platform;
cl_device_type Type = 0;
ur_device_handle_t ParentDevice = nullptr;
std::atomic<uint32_t> RefCount = 0;

ur_device_handle_t_(native_type Dev, ur_platform_handle_t Plat,
ur_device_handle_t Parent)
: Device(Dev), Platform(Plat), ParentDevice(Parent) {
RefCount = 1;
if (Parent) {
Type = Parent->Type;
} else {
Expand All @@ -30,7 +32,13 @@ struct ur_device_handle_t_ {
}
}

~ur_device_handle_t_() {}
~ur_device_handle_t_() { clReleaseDevice(Device); }

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }

native_type get() { return Device; }

Expand Down
73 changes: 41 additions & 32 deletions source/adapters/opencl/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventCreateWithNativeHandle(
}

if (!pProperties || !pProperties->isNativeHandleOwned) {
return urEventRetain(*phEvent);
CL_RETURN_ON_FAILURE(clRetainEvent(NativeHandle));
}
return UR_RESULT_SUCCESS;
}
Expand All @@ -137,14 +137,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
cl_int RetErr = clReleaseEvent(hEvent->get());
CL_RETURN_ON_FAILURE(RetErr);
if (hEvent->decrementReferenceCount() == 0) {
delete hEvent;
}
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
cl_int RetErr = clRetainEvent(hEvent->get());
CL_RETURN_ON_FAILURE(RetErr);
hEvent->incrementReferenceCount();
return UR_RESULT_SUCCESS;
}

Expand All @@ -167,42 +167,51 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
cl_event_info CLEventInfo = convertUREventInfoToCL(propName);
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

if (CLEventInfo == CL_EVENT_CONTEXT) {
switch (propName) {
case UR_EVENT_INFO_CONTEXT: {
return ReturnValue(hEvent->Context);
}
if (CLEventInfo == CL_EVENT_COMMAND_QUEUE) {
case UR_EVENT_INFO_COMMAND_QUEUE: {
return ReturnValue(hEvent->Queue);
}
size_t CheckPropSize = 0;
cl_int RetErr = clGetEventInfo(hEvent->get(), CLEventInfo, propSize,
pPropValue, &CheckPropSize);
if (pPropValue && CheckPropSize != propSize) {
return UR_RESULT_ERROR_INVALID_SIZE;
}
CL_RETURN_ON_FAILURE(RetErr);
if (pPropSizeRet) {
*pPropSizeRet = CheckPropSize;
case UR_EVENT_INFO_REFERENCE_COUNT: {
return ReturnValue(hEvent->getReferenceCount());
}
default: {
size_t CheckPropSize = 0;
cl_int RetErr = clGetEventInfo(hEvent->get(), CLEventInfo, propSize,
pPropValue, &CheckPropSize);
if (pPropValue && CheckPropSize != propSize) {
return UR_RESULT_ERROR_INVALID_SIZE;
}
CL_RETURN_ON_FAILURE(RetErr);
if (pPropSizeRet) {
*pPropSizeRet = CheckPropSize;
}

if (pPropValue) {
if (propName == UR_EVENT_INFO_COMMAND_TYPE) {
*reinterpret_cast<ur_command_t *>(pPropValue) = convertCLCommandTypeToUR(
*reinterpret_cast<cl_command_type *>(pPropValue));
} else if (propName == UR_EVENT_INFO_COMMAND_EXECUTION_STATUS) {
/* If the CL_EVENT_COMMAND_EXECUTION_STATUS info value is CL_QUEUED,
* change it to CL_SUBMITTED. sycl::info::event::event_command_status has
* no equivalent to CL_QUEUED.
*
* FIXME UR Port: This should not be part of the UR adapter. Since
* PI_QUEUED exists, SYCL RT should be changed to handle this situation.
* In addition, SYCL RT is relying on PI_QUEUED status to make sure that
* the queues are flushed. */
const auto param_value_int = static_cast<ur_event_status_t *>(pPropValue);
if (*param_value_int == UR_EVENT_STATUS_QUEUED) {
*param_value_int = UR_EVENT_STATUS_SUBMITTED;
if (pPropValue) {
if (propName == UR_EVENT_INFO_COMMAND_TYPE) {
*reinterpret_cast<ur_command_t *>(pPropValue) =
convertCLCommandTypeToUR(
*reinterpret_cast<cl_command_type *>(pPropValue));
} else if (propName == UR_EVENT_INFO_COMMAND_EXECUTION_STATUS) {
/* If the CL_EVENT_COMMAND_EXECUTION_STATUS info value is CL_QUEUED,
* change it to CL_SUBMITTED. sycl::info::event::event_command_status
* has no equivalent to CL_QUEUED.
*
* FIXME UR Port: This should not be part of the UR adapter. Since
* PI_QUEUED exists, SYCL RT should be changed to handle this situation.
* In addition, SYCL RT is relying on PI_QUEUED status to make sure that
* the queues are flushed. */
const auto param_value_int =
static_cast<ur_event_status_t *>(pPropValue);
if (*param_value_int == UR_EVENT_STATUS_QUEUED) {
*param_value_int = UR_EVENT_STATUS_SUBMITTED;
}
}
}
}
}

return UR_RESULT_SUCCESS;
}
Expand Down
27 changes: 25 additions & 2 deletions source/adapters/opencl/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,35 @@ struct ur_event_handle_t_ {
native_type Event;
ur_context_handle_t Context;
ur_queue_handle_t Queue;
std::atomic<uint32_t> RefCount = 0;

ur_event_handle_t_(native_type Event, ur_context_handle_t Ctx,
ur_queue_handle_t Queue)
: Event(Event), Context(Ctx), Queue(Queue) {}
: Event(Event), Context(Ctx), Queue(Queue) {
RefCount = 1;
if (Context) {
urContextRetain(Context);
}
if (Queue) {
urQueueRetain(Queue);
}
}

~ur_event_handle_t_() {}
~ur_event_handle_t_() {
if (Context) {
urContextRelease(Context);
}
if (Queue) {
urQueueRelease(Queue);
}
clReleaseEvent(Event);
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }

native_type get() { return Event; }
};
Loading

0 comments on commit 65247f1

Please sign in to comment.