Skip to content

Commit

Permalink
Modify handling of CreateWithNative
Browse files Browse the repository at this point in the history
  • Loading branch information
omarahmed1111 committed Jan 31, 2024
1 parent 65247f1 commit 9e6e28f
Show file tree
Hide file tree
Showing 14 changed files with 119 additions and 151 deletions.
12 changes: 7 additions & 5 deletions source/adapters/opencl/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t hNativeContext, uint32_t numDevices, const ur_device_handle_t *phDevices,
const ur_context_native_properties_t *pProperties, ur_context_handle_t *phContext) {
ur_native_handle_t hNativeContext, uint32_t numDevices,
const ur_device_handle_t *phDevices,
const ur_context_native_properties_t *pProperties,
ur_context_handle_t *phContext) {

cl_context NativeHandle = reinterpret_cast<cl_context>(hNativeContext);
auto URContext = std::make_unique<ur_context_handle_t_>(
NativeHandle, numDevices, phDevices);
UR_RETURN_ON_FAILURE(ur_context_handle_t_::makeWithNative(
NativeHandle, numDevices, phDevices, *phContext));

if (!pProperties || !pProperties->isNativeHandleOwned) {
return clRetainContext(NativeHandle);
CL_RETURN_ON_FAILURE(clRetainContext(NativeHandle));
}

return UR_RESULT_SUCCESS;
}

Expand Down
41 changes: 21 additions & 20 deletions source/adapters/opencl/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,28 @@ struct ur_context_handle_t_ {
static ur_result_t makeWithNative(native_type Ctx, uint32_t DevCount,
const ur_device_handle_t *phDevices,
ur_context_handle_t &Context) {
if (!phDevices) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}
try {
auto URContext =
std::make_unique<ur_context_handle_t_>(Ctx, DevCount, phDevices);
CL_RETURN_ON_FAILURE(clRetainContext(Ctx));
native_type &NativeContext = URContext->Context;
uint32_t &DeviceCount = URContext->DeviceCount;
if (!DeviceCount) {
CL_RETURN_ON_FAILURE(
clGetContextInfo(NativeContext, CL_CONTEXT_NUM_DEVICES,
sizeof(DeviceCount), &DeviceCount, nullptr));
std::vector<cl_device_id> CLDevices(DeviceCount);
CL_RETURN_ON_FAILURE(clGetContextInfo(NativeContext, CL_CONTEXT_DEVICES,
sizeof(CLDevices),
CLDevices.data(), nullptr));
URContext->Devices.resize(DeviceCount);
for (uint32_t i = 0; i < DeviceCount; i++) {
ur_native_handle_t NativeDevice =
reinterpret_cast<ur_native_handle_t>(CLDevices[i]);
UR_RETURN_ON_FAILURE(urDeviceCreateWithNativeHandle(
NativeDevice, nullptr, nullptr, &(URContext->Devices[i])));
UR_RETURN_ON_FAILURE(urDeviceRetain(URContext->Devices[i]));
uint32_t CLDeviceCount;
CL_RETURN_ON_FAILURE(clGetContextInfo(Ctx, CL_CONTEXT_NUM_DEVICES,
sizeof(CLDeviceCount),
&CLDeviceCount, nullptr));
std::vector<cl_device_id> CLDevices(CLDeviceCount);
CL_RETURN_ON_FAILURE(clGetContextInfo(Ctx, CL_CONTEXT_DEVICES,
sizeof(CLDevices), CLDevices.data(),
nullptr));
if (DevCount != CLDeviceCount) {
return UR_RESULT_ERROR_INVALID_CONTEXT;
}
for (uint32_t i = 0; i < DevCount; i++) {
if (phDevices[i]->get() != CLDevices[i]) {
return UR_RESULT_ERROR_INVALID_CONTEXT;
}
}
auto URContext =
std::make_unique<ur_context_handle_t_>(Ctx, DevCount, phDevices);
Context = URContext.release();
} catch (std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
Expand All @@ -81,4 +80,6 @@ struct ur_context_handle_t_ {
}

native_type get() { return Context; }

const std::vector<ur_device_handle_t> &getDevices() { return Devices; }
};
8 changes: 8 additions & 0 deletions source/adapters/opencl/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,8 +966,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urDevicePartition(
CLSubDevices[i], hDevice->Platform, hDevice);
phSubDevices[i] = URSubDevice.release();
} catch (std::bad_alloc &) {
// Delete all the successfully created subdevices before the failed one.
for (uint32_t j = 0; j < i; j++) {
delete phSubDevices[j];
}
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
} catch (...) {
// Delete all the successfully created subdevices before the failed one.
for (uint32_t j = 0; j < i; j++) {
delete phSubDevices[j];
}
return UR_RESULT_ERROR_UNKNOWN;
}
}
Expand Down
8 changes: 2 additions & 6 deletions source/adapters/opencl/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,14 @@ struct ur_event_handle_t_ {
ur_queue_handle_t Queue)
: Event(Event), Context(Ctx), Queue(Queue) {
RefCount = 1;
if (Context) {
urContextRetain(Context);
}
urContextRetain(Context);
if (Queue) {
urQueueRetain(Queue);
}
}

~ur_event_handle_t_() {
if (Context) {
urContextRelease(Context);
}
urContextRelease(Context);
if (Queue) {
urQueueRelease(Queue);
}
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/opencl/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
NativeHandle, hProgram, hContext, *phKernel));

if (!pProperties || !pProperties->isNativeHandleOwned) {
CL_RETURN_ON_FAILURE(clRetainKernel((*phKernel)->get()));
CL_RETURN_ON_FAILURE(clRetainKernel(NativeHandle));
}
return UR_RESULT_SUCCESS;
}
Expand Down
51 changes: 19 additions & 32 deletions source/adapters/opencl/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "common.hpp"
#include "context.hpp"
#include "program.hpp"

#include <vector>

Expand All @@ -25,22 +26,14 @@ struct ur_kernel_handle_t_ {
ur_context_handle_t Context)
: Kernel(Kernel), Program(Program), Context(Context) {
RefCount = 1;
if (Program) {
urProgramRetain(Program);
}
if (Context) {
urContextRetain(Context);
}
urProgramRetain(Program);
urContextRetain(Context);
}

~ur_kernel_handle_t_() {
clReleaseKernel(Kernel);
if (Program) {
urProgramRelease(Program);
}
if (Context) {
urContextRelease(Context);
}
urProgramRelease(Program);
urContextRelease(Context);
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
Expand All @@ -53,33 +46,27 @@ struct ur_kernel_handle_t_ {
ur_program_handle_t Program,
ur_context_handle_t Context,
ur_kernel_handle_t &Kernel) {
if (!Program || !Context) {
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}
try {
auto URKernel =
std::make_unique<ur_kernel_handle_t_>(NativeKernel, Program, Context);
if (!Program) {
cl_program CLProgram;
CL_RETURN_ON_FAILURE(clGetKernelInfo(NativeKernel, CL_KERNEL_PROGRAM,
sizeof(CLProgram), &CLProgram,
nullptr));
ur_native_handle_t NativeProgram =
reinterpret_cast<ur_native_handle_t>(CLProgram);
UR_RETURN_ON_FAILURE(urProgramCreateWithNativeHandle(
NativeProgram, nullptr, nullptr, &(URKernel->Program)));
UR_RETURN_ON_FAILURE(urProgramRetain(URKernel->Program));
}
cl_context CLContext;
CL_RETURN_ON_FAILURE(clGetKernelInfo(NativeKernel, CL_KERNEL_CONTEXT,
sizeof(CLContext), &CLContext,
nullptr));
if (!Context) {
ur_native_handle_t NativeContext =
reinterpret_cast<ur_native_handle_t>(CLContext);
UR_RETURN_ON_FAILURE(urContextCreateWithNativeHandle(
NativeContext, 0, nullptr, nullptr, &(URKernel->Context)));
UR_RETURN_ON_FAILURE(urContextRetain(URKernel->Context));
} else if (Context->get() != CLContext) {
cl_program CLProgram;
CL_RETURN_ON_FAILURE(clGetKernelInfo(NativeKernel, CL_KERNEL_PROGRAM,
sizeof(CLProgram), &CLProgram,
nullptr));

if (Context->get() != CLContext) {
return UR_RESULT_ERROR_INVALID_CONTEXT;
}
if (Program->get() != CLProgram) {
return UR_RESULT_ERROR_INVALID_PROGRAM;
}
auto URKernel =
std::make_unique<ur_kernel_handle_t_>(NativeKernel, Program, Context);
Kernel = URKernel.release();
} catch (std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
Expand Down
6 changes: 3 additions & 3 deletions source/adapters/opencl/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
return mapCLErrorToUR(RetErr);
}

UR_APIEXPORT ur_result_t UR_APICALL
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) {
return getNativeHandle(hMem->get(), phNativeMem);
}

Expand All @@ -390,7 +390,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreateWithNativeHandle(
UR_RETURN_ON_FAILURE(
ur_mem_handle_t_::makeWithNative(NativeHandle, hContext, *phMem));
if (!pProperties || !pProperties->isNativeHandleOwned) {
CL_RETURN_ON_FAILURE(clRetainMemObject((*phMem)->get()));
CL_RETURN_ON_FAILURE(clRetainMemObject(NativeHandle));
}
return UR_RESULT_SUCCESS;
}
Expand Down
29 changes: 13 additions & 16 deletions source/adapters/opencl/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#pragma once

#include "common.hpp"
#include "context.hpp"

#include <vector>

Expand All @@ -22,16 +23,12 @@ struct ur_mem_handle_t_ {
ur_mem_handle_t_(native_type Mem, ur_context_handle_t Ctx)
: Memory(Mem), Context(Ctx) {
RefCount = 1;
if (Context) {
urContextRetain(Context);
}
urContextRetain(Context);
}

~ur_mem_handle_t_() {
clReleaseMemObject(Memory);
if (Context) {
urContextRelease(Context);
}
urContextRelease(Context);
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
Expand All @@ -43,18 +40,18 @@ struct ur_mem_handle_t_ {
static ur_result_t makeWithNative(native_type NativeMem,
ur_context_handle_t Ctx,
ur_mem_handle_t &Mem) {
if (!Ctx) {
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}
try {
auto URMem = std::make_unique<ur_mem_handle_t_>(NativeMem, Ctx);
if (!Ctx) {
cl_context CLContext;
CL_RETURN_ON_FAILURE(clGetMemObjectInfo(
NativeMem, CL_MEM_CONTEXT, sizeof(CLContext), &CLContext, nullptr));
ur_native_handle_t NativeContext =
reinterpret_cast<ur_native_handle_t>(CLContext);
UR_RETURN_ON_FAILURE(urContextCreateWithNativeHandle(
NativeContext, 0, nullptr, nullptr, &(URMem->Context)));
UR_RETURN_ON_FAILURE(urContextRetain(URMem->Context));
cl_context CLContext;
CL_RETURN_ON_FAILURE(clGetMemObjectInfo(
NativeMem, CL_MEM_CONTEXT, sizeof(CLContext), &CLContext, nullptr));

if (Ctx->get() != CLContext) {
return UR_RESULT_ERROR_INVALID_CONTEXT;
}
auto URMem = std::make_unique<ur_mem_handle_t_>(NativeMem, Ctx);
Mem = URMem.release();
} catch (std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
Expand Down
1 change: 0 additions & 1 deletion source/adapters/opencl/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ struct ur_platform_handle_t_ {
}

ur_result_t getPlatformVersion(oclv::OpenCLVersion &Version) {

size_t PlatVerSize = 0;
CL_RETURN_ON_FAILURE(clGetPlatformInfo(Platform, CL_PLATFORM_VERSION, 0,
nullptr, &PlatVerSize));
Expand Down
11 changes: 4 additions & 7 deletions source/adapters/opencl/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
if (PlatVer >= oclv::V2_1) {

/* Make sure all devices support CL 2.1 or newer as well. */
for (ur_device_handle_t URDev : hContext->Devices) {
for (ur_device_handle_t URDev : hContext->getDevices()) {
oclv::OpenCLVersion DevVer;

CL_RETURN_ON_FAILURE_AND_SET_NULL(URDev->getDeviceVersion(DevVer),
Expand Down Expand Up @@ -70,7 +70,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
/* If none of the devices conform with CL 2.1 or newer make sure they all
* support the cl_khr_il_program extension.
*/
for (ur_device_handle_t URDev : hContext->Devices) {
for (ur_device_handle_t URDev : hContext->getDevices()) {
bool Supported = false;
CL_RETURN_ON_FAILURE_AND_SET_NULL(
URDev->checkDeviceExtensions({"cl_khr_il_program"}, Supported),
Expand Down Expand Up @@ -178,7 +178,6 @@ static cl_int mapURProgramInfoToCL(ur_program_info_t URPropName) {
UR_APIEXPORT ur_result_t UR_APICALL
urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName,
size_t propSize, void *pPropValue, size_t *pPropSizeRet) {

UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

const cl_program_info CLPropName = mapURProgramInfoToCL(propName);
Expand Down Expand Up @@ -368,7 +367,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
UR_RETURN_ON_FAILURE(
ur_program_handle_t_::makeWithNative(NativeHandle, hContext, *phProgram));
if (!pProperties || !pProperties->isNativeHandleOwned) {
CL_RETURN_ON_FAILURE(clRetainProgram((*phProgram)->get()));
CL_RETURN_ON_FAILURE(clRetainProgram(NativeHandle));
}
return UR_RESULT_SUCCESS;
}
Expand All @@ -386,8 +385,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants(
return UR_RESULT_ERROR_INVALID_CONTEXT;
}

std::vector<ur_device_handle_t> &DevicesInCtx = Ctx->Devices;

ur_platform_handle_t CurPlatform = Ctx->Devices[0]->Platform;

oclv::OpenCLVersion PlatVer;
Expand All @@ -397,7 +394,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants(
if (PlatVer < oclv::V2_2) {
UseExtensionLookup = true;
} else {
for (ur_device_handle_t Dev : DevicesInCtx) {
for (ur_device_handle_t Dev : Ctx->getDevices()) {
oclv::OpenCLVersion DevVer;

UR_RETURN_ON_FAILURE(Dev->getDeviceVersion(DevVer));
Expand Down
Loading

0 comments on commit 9e6e28f

Please sign in to comment.