Skip to content

Commit

Permalink
Add kernel handle
Browse files Browse the repository at this point in the history
  • Loading branch information
omarahmed1111 committed Dec 20, 2023
1 parent 9cf7e2f commit f7b0de5
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 63 deletions.
8 changes: 4 additions & 4 deletions source/adapters/opencl/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "common.hpp"
#include "context.hpp"
#include "event.hpp"
#include "kernel.hpp"
#include "memory.hpp"
#include "queue.hpp"

Expand Down Expand Up @@ -120,10 +121,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
return UR_RESULT_ERROR_INVALID_OPERATION;

CL_RETURN_ON_FAILURE(clCommandNDRangeKernelKHR(
hCommandBuffer->CLCommandBuffer, nullptr, nullptr,
cl_adapter::cast<cl_kernel>(hKernel), workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize, numSyncPointsInWaitList,
pSyncPointWaitList, pSyncPoint, nullptr));
hCommandBuffer->CLCommandBuffer, nullptr, nullptr, hKernel->get(),
workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint, nullptr));

return UR_RESULT_SUCCESS;
}
Expand Down
9 changes: 5 additions & 4 deletions source/adapters/opencl/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "common.hpp"
#include "context.hpp"
#include "event.hpp"
#include "kernel.hpp"
#include "memory.hpp"
#include "program.hpp"
#include "queue.hpp"
Expand Down Expand Up @@ -40,10 +41,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
for (uint32_t i = 0; i < numEventsInWaitList; i++) {
CLWaitEvents[i] = phEventWaitList[i]->get();
}
CL_RETURN_ON_FAILURE(clEnqueueNDRangeKernel(
hQueue->get(), cl_adapter::cast<cl_kernel>(hKernel), workDim,
pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList,
CLWaitEvents.data(), &Event));
CL_RETURN_ON_FAILURE(
clEnqueueNDRangeKernel(hQueue->get(), hKernel->get(), workDim,
pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, CLWaitEvents.data(), &Event));
if (phEvent) {
auto UREvent =
std::make_unique<ur_event_handle_t_>(Event, hQueue->Context, hQueue);
Expand Down
107 changes: 57 additions & 50 deletions source/adapters/opencl/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "kernel.hpp"
#include "common.hpp"
#include "device.hpp"
#include "memory.hpp"
Expand All @@ -21,19 +22,20 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
ur_kernel_handle_t *phKernel) {

cl_int CLResult;
*phKernel = cl_adapter::cast<ur_kernel_handle_t>(
clCreateKernel(hProgram->get(), pKernelName, &CLResult));
cl_kernel Kernel = clCreateKernel(hProgram->get(), pKernelName, &CLResult);
CL_RETURN_ON_FAILURE(CLResult);
auto URKernel =
std::make_unique<ur_kernel_handle_t_>(Kernel, hProgram, nullptr);
*phKernel = URKernel.release();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
const ur_kernel_arg_value_properties_t *, const void *pArgValue) {

CL_RETURN_ON_FAILURE(clSetKernelArg(cl_adapter::cast<cl_kernel>(hKernel),
cl_adapter::cast<cl_uint>(argIndex),
argSize, pArgValue));
CL_RETURN_ON_FAILURE(clSetKernelArg(
hKernel->get(), cl_adapter::cast<cl_uint>(argIndex), argSize, pArgValue));

return UR_RESULT_SUCCESS;
}
Expand All @@ -42,9 +44,8 @@ UR_APIEXPORT ur_result_t UR_APICALL
urKernelSetArgLocal(ur_kernel_handle_t hKernel, uint32_t argIndex,
size_t argSize, const ur_kernel_arg_local_properties_t *) {

CL_RETURN_ON_FAILURE(clSetKernelArg(cl_adapter::cast<cl_kernel>(hKernel),
cl_adapter::cast<cl_uint>(argIndex),
argSize, nullptr));
CL_RETURN_ON_FAILURE(clSetKernelArg(
hKernel->get(), cl_adapter::cast<cl_uint>(argIndex), argSize, nullptr));

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -76,26 +77,31 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,
size_t propSize,
void *pPropValue,
size_t *pPropSizeRet) {
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
// We need this little bit of ugliness because the UR NUM_ARGS property is
// size_t whereas the CL one is cl_uint. We should consider changing that see
// #1038
if (propName == UR_KERNEL_INFO_NUM_ARGS) {
if (pPropSizeRet)
*pPropSizeRet = sizeof(size_t);
cl_uint NumArgs = 0;
CL_RETURN_ON_FAILURE(clGetKernelInfo(cl_adapter::cast<cl_kernel>(hKernel),
CL_RETURN_ON_FAILURE(clGetKernelInfo(hKernel->get(),
mapURKernelInfoToCL(propName),
sizeof(NumArgs), &NumArgs, nullptr));
if (pPropValue) {
if (propSize != sizeof(size_t))
return UR_RESULT_ERROR_INVALID_SIZE;
*static_cast<size_t *>(pPropValue) = static_cast<size_t>(NumArgs);
}
} else if (propName == UR_KERNEL_INFO_PROGRAM) {
return ReturnValue(hKernel->Program);
} else if (propName == UR_KERNEL_INFO_CONTEXT) {
return ReturnValue(hKernel->Context);
} else {
size_t CheckPropSize = 0;
cl_int ClResult = clGetKernelInfo(cl_adapter::cast<cl_kernel>(hKernel),
mapURKernelInfoToCL(propName), propSize,
pPropValue, &CheckPropSize);
cl_int ClResult =
clGetKernelInfo(hKernel->get(), mapURKernelInfoToCL(propName), propSize,
pPropValue, &CheckPropSize);
if (pPropValue && CheckPropSize != propSize) {
return UR_RESULT_ERROR_INVALID_SIZE;
}
Expand Down Expand Up @@ -147,8 +153,8 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
}
}
CL_RETURN_ON_FAILURE(clGetKernelWorkGroupInfo(
cl_adapter::cast<cl_kernel>(hKernel), hDevice->get(),
mapURKernelGroupInfoToCL(propName), propSize, pPropValue, pPropSizeRet));
hKernel->get(), hDevice->get(), mapURKernelGroupInfoToCL(propName),
propSize, pPropValue, pPropSizeRet));

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -201,9 +207,8 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
}

cl_int Ret = clGetKernelSubGroupInfo(
cl_adapter::cast<cl_kernel>(hKernel), hDevice->get(),
mapURKernelSubGroupInfoToCL(propName), InputValueSize, InputValue.get(),
sizeof(size_t), &RetVal, pPropSizeRet);
hKernel->get(), hDevice->get(), mapURKernelSubGroupInfoToCL(propName),
InputValueSize, InputValue.get(), sizeof(size_t), &RetVal, pPropSizeRet);

if (Ret == CL_INVALID_OPERATION) {
// clGetKernelSubGroupInfo returns CL_INVALID_OPERATION if the device does
Expand Down Expand Up @@ -252,13 +257,13 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) {
CL_RETURN_ON_FAILURE(clRetainKernel(cl_adapter::cast<cl_kernel>(hKernel)));
CL_RETURN_ON_FAILURE(clRetainKernel(hKernel->get()));
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urKernelRelease(ur_kernel_handle_t hKernel) {
CL_RETURN_ON_FAILURE(clReleaseKernel(cl_adapter::cast<cl_kernel>(hKernel)));
CL_RETURN_ON_FAILURE(clReleaseKernel(hKernel->get()));
return UR_RESULT_SUCCESS;
}

Expand All @@ -276,41 +281,38 @@ static ur_result_t usmSetIndirectAccess(ur_kernel_handle_t hKernel) {

/* We test that each alloc type is supported before we actually try to set
* KernelExecInfo. */
CL_RETURN_ON_FAILURE(clGetKernelInfo(cl_adapter::cast<cl_kernel>(hKernel),
CL_KERNEL_CONTEXT, sizeof(cl_context),
&CLContext, nullptr));
CL_RETURN_ON_FAILURE(clGetKernelInfo(hKernel->get(), CL_KERNEL_CONTEXT,
sizeof(cl_context), &CLContext,
nullptr));

UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache,
cl_ext::HostMemAllocName, &HFunc));

if (HFunc) {
CL_RETURN_ON_FAILURE(
clSetKernelExecInfo(cl_adapter::cast<cl_kernel>(hKernel),
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal));
CL_RETURN_ON_FAILURE(clSetKernelExecInfo(
hKernel->get(), CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal));
}

UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
CLContext, cl_ext::ExtFuncPtrCache->clDeviceMemAllocINTELCache,
cl_ext::DeviceMemAllocName, &DFunc));

if (DFunc) {
CL_RETURN_ON_FAILURE(
clSetKernelExecInfo(cl_adapter::cast<cl_kernel>(hKernel),
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal));
CL_RETURN_ON_FAILURE(clSetKernelExecInfo(
hKernel->get(), CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal));
}

UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
CLContext, cl_ext::ExtFuncPtrCache->clSharedMemAllocINTELCache,
cl_ext::SharedMemAllocName, &SFunc));

if (SFunc) {
CL_RETURN_ON_FAILURE(
clSetKernelExecInfo(cl_adapter::cast<cl_kernel>(hKernel),
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal));
CL_RETURN_ON_FAILURE(clSetKernelExecInfo(
hKernel->get(), CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal));
}
return UR_RESULT_SUCCESS;
}
Expand All @@ -332,9 +334,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
return UR_RESULT_SUCCESS;
}
case UR_KERNEL_EXEC_INFO_USM_PTRS: {
CL_RETURN_ON_FAILURE(clSetKernelExecInfo(
cl_adapter::cast<cl_kernel>(hKernel),
CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL, propSize, pPropValue));
CL_RETURN_ON_FAILURE(clSetKernelExecInfo(hKernel->get(),
CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL,
propSize, pPropValue));
return UR_RESULT_SUCCESS;
}
default: {
Expand All @@ -348,9 +350,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) {

cl_context CLContext;
CL_RETURN_ON_FAILURE(clGetKernelInfo(cl_adapter::cast<cl_kernel>(hKernel),
CL_KERNEL_CONTEXT, sizeof(cl_context),
&CLContext, nullptr));
CL_RETURN_ON_FAILURE(clGetKernelInfo(hKernel->get(), CL_KERNEL_CONTEXT,
sizeof(cl_context), &CLContext,
nullptr));

clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr;
UR_RETURN_ON_FAILURE(
Expand All @@ -364,25 +366,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
* deref the arg to get the pointer value */
auto PtrToPtr = reinterpret_cast<const intptr_t *>(pArgValue);
auto DerefPtr = reinterpret_cast<void *>(*PtrToPtr);
CL_RETURN_ON_FAILURE(FuncPtr(cl_adapter::cast<cl_kernel>(hKernel),
cl_adapter::cast<cl_uint>(argIndex),
DerefPtr));
CL_RETURN_ON_FAILURE(
FuncPtr(hKernel->get(), cl_adapter::cast<cl_uint>(argIndex), DerefPtr));
}

return UR_RESULT_SUCCESS;
}
UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
ur_kernel_handle_t hKernel, ur_native_handle_t *phNativeKernel) {

*phNativeKernel = reinterpret_cast<ur_native_handle_t>(hKernel);
*phNativeKernel = reinterpret_cast<ur_native_handle_t>(hKernel->get());
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
ur_native_handle_t hNativeKernel, ur_context_handle_t, ur_program_handle_t,
ur_native_handle_t hNativeKernel, ur_context_handle_t hContext,
ur_program_handle_t hProgram,
const ur_kernel_native_properties_t *pProperties,
ur_kernel_handle_t *phKernel) {
*phKernel = reinterpret_cast<ur_kernel_handle_t>(hNativeKernel);
cl_kernel NativeHandle = reinterpret_cast<cl_kernel>(hNativeKernel);
auto URKernel =
std::make_unique<ur_kernel_handle_t_>(NativeHandle, hProgram, hContext);
UR_RETURN_ON_FAILURE(URKernel->initWithNative());
*phKernel = URKernel.release();

if (!pProperties || !pProperties->isNativeHandleOwned) {
return urKernelRetain(*phKernel);
}
Expand All @@ -394,7 +401,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgMemObj(
const ur_kernel_arg_mem_obj_properties_t *, ur_mem_handle_t hArgValue) {

cl_mem CLArgValue = hArgValue ? hArgValue->get() : nullptr;
CL_RETURN_ON_FAILURE(clSetKernelArg(cl_adapter::cast<cl_kernel>(hKernel),
CL_RETURN_ON_FAILURE(clSetKernelArg(hKernel->get(),
cl_adapter::cast<cl_uint>(argIndex),
sizeof(CLArgValue), &CLArgValue));
return UR_RESULT_SUCCESS;
Expand All @@ -405,9 +412,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgSampler(
const ur_kernel_arg_sampler_properties_t *, ur_sampler_handle_t hArgValue) {

cl_sampler CLArgSampler = hArgValue->get();
cl_int RetErr = clSetKernelArg(cl_adapter::cast<cl_kernel>(hKernel),
cl_adapter::cast<cl_uint>(argIndex),
sizeof(CLArgSampler), &CLArgSampler);
cl_int RetErr =
clSetKernelArg(hKernel->get(), cl_adapter::cast<cl_uint>(argIndex),
sizeof(CLArgSampler), &CLArgSampler);
CL_RETURN_ON_FAILURE(RetErr);
return UR_RESULT_SUCCESS;
}
51 changes: 51 additions & 0 deletions source/adapters/opencl/kernel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===--------- kernel.hpp - OpenCL Adapter ---------------------------===//
//
// Copyright (C) 2023 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#pragma once

#include "common.hpp"

#include <vector>

/// Adapter-side kernel object: pairs a native OpenCL `cl_kernel` with the UR
/// program and context handles associated with it.
///
/// `Program` and/or `Context` may be passed as null at construction; calling
/// `initWithNative()` afterwards fills in whichever of the two is still null
/// by querying the native kernel.
struct ur_kernel_handle_t_ {
  using native_type = cl_kernel;

  /// The wrapped native OpenCL kernel.
  native_type Kernel;
  /// UR program handle for this kernel (may be null until initWithNative()).
  ur_program_handle_t Program;
  /// UR context handle for this kernel (may be null until initWithNative()).
  ur_context_handle_t Context;

  /// Stores the three handles as given; performs no OpenCL calls.
  ur_kernel_handle_t_(native_type Kernel, ur_program_handle_t Program,
                      ur_context_handle_t Context)
      : Kernel(Kernel), Program(Program), Context(Context) {}

  // NOTE(review): the destructor releases nothing. Handles created by
  // initWithNative() are not released here -- confirm that their lifetime is
  // managed elsewhere.
  ~ur_kernel_handle_t_() {}

  /// Populates any missing `Program` / `Context` from the native kernel.
  ///
  /// For each member that is null, queries the corresponding native object
  /// via `clGetKernelInfo` and wraps it with the matching
  /// `ur*CreateWithNativeHandle` entry point. Members that are already
  /// non-null are left untouched.
  ///
  /// @return UR_RESULT_SUCCESS on success; otherwise the failure is
  ///         propagated by CL_RETURN_ON_FAILURE / UR_RETURN_ON_FAILURE from
  ///         the first call that fails.
  ur_result_t initWithNative() {
    if (!Program) {
      // Initialised so the out-parameter never holds an indeterminate value.
      cl_program CLProgram = nullptr;
      CL_RETURN_ON_FAILURE(clGetKernelInfo(
          Kernel, CL_KERNEL_PROGRAM, sizeof(CLProgram), &CLProgram, nullptr));
      ur_native_handle_t NativeProgram =
          reinterpret_cast<ur_native_handle_t>(CLProgram);
      UR_RETURN_ON_FAILURE(urProgramCreateWithNativeHandle(
          NativeProgram, nullptr, nullptr, &Program));
    }
    if (!Context) {
      // Initialised so the out-parameter never holds an indeterminate value.
      cl_context CLContext = nullptr;
      CL_RETURN_ON_FAILURE(clGetKernelInfo(
          Kernel, CL_KERNEL_CONTEXT, sizeof(CLContext), &CLContext, nullptr));
      ur_native_handle_t NativeContext =
          reinterpret_cast<ur_native_handle_t>(CLContext);
      UR_RETURN_ON_FAILURE(urContextCreateWithNativeHandle(
          NativeContext, 0, nullptr, nullptr, &Context));
    }
    return UR_RESULT_SUCCESS;
  }

  /// Returns the wrapped native `cl_kernel`. Const-qualified so it can be
  /// called through a const handle as well as a mutable one.
  native_type get() const noexcept { return Kernel; }
};
8 changes: 4 additions & 4 deletions source/adapters/opencl/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ UR_APIEXPORT ur_result_t UR_APICALL
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {

static std::vector<std::unique_ptr<ur_platform_handle_t_>> URPlatforms;
static std::vector<ur_platform_handle_t> URPlatforms;
static std::once_flag InitFlag;
static uint32_t NumPlatforms = 0;
cl_int Result = CL_SUCCESS;
Expand All @@ -105,10 +105,10 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
if (Result != CL_SUCCESS) {
return Result;
}
URPlatforms.resize(NumPlatforms);
for (uint32_t i = 0; i < NumPlatforms; i++) {
URPlatforms[i] =
auto URPlatform =
std::make_unique<ur_platform_handle_t_>(CLPlatforms[i]);
URPlatforms.emplace_back(URPlatform.release());
}
return Result;
},
Expand All @@ -126,7 +126,7 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
}
if (NumEntries && phPlatforms) {
for (uint32_t i = 0; i < NumEntries; i++) {
phPlatforms[i] = URPlatforms[i].get();
phPlatforms[i] = URPlatforms[i];
}
}
return mapCLErrorToUR(Result);
Expand Down
1 change: 1 addition & 0 deletions source/adapters/opencl/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithNativeHandle(

auto URProgram =
std::make_unique<ur_program_handle_t_>(NativeHandle, hContext);
UR_RETURN_ON_FAILURE(URProgram->initWithNative());
*phProgram = URProgram.release();
if (!pProperties || !pProperties->isNativeHandleOwned) {
return urProgramRetain(*phProgram);
Expand Down
Loading

0 comments on commit f7b0de5

Please sign in to comment.