Skip to content

Commit

Permalink
[SPEC] Add urProgramGetGlobalVariablePointer entrypoint
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiomestre committed Jan 16, 2024
1 parent c63ad9b commit f5e3493
Show file tree
Hide file tree
Showing 30 changed files with 601 additions and 10 deletions.
49 changes: 49 additions & 0 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ typedef enum ur_function_t {
UR_FUNCTION_COMMAND_BUFFER_APPEND_USM_ADVISE_EXP = 213, ///< Enumerator for ::urCommandBufferAppendUSMAdviseExp
UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 214, ///< Enumerator for ::urEnqueueCooperativeKernelLaunchExp
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 215, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
UR_FUNCTION_PROGRAM_GET_GLOBAL_VARIABLE_POINTER = 216, ///< Enumerator for ::urProgramGetGlobalVariablePointer
/// @cond
UR_FUNCTION_FORCE_UINT32 = 0x7fffffff
/// @endcond
Expand Down Expand Up @@ -4262,6 +4263,42 @@ urProgramGetFunctionPointer(
void **ppFunctionPointer ///< [out] Returns the pointer to the function if it is found in the program.
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Retrieves a pointer to a device global variable.
///
/// @details
/// - Retrieves a pointer to a device global variable.
/// - The application may call this function from simultaneous threads for
/// the same device.
/// - The implementation of this function should be thread-safe.
///
/// @remarks
/// _Analogues_
/// - **clGetDeviceGlobalVariablePointerINTEL**
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hDevice`
/// + `NULL == hProgram`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == pGlobalVariableName`
/// + `NULL == ppGlobalVariablePointerRet`
/// - ::UR_RESULT_ERROR_INVALID_SIZE
/// + `name` is not a valid variable in the program.
UR_APIEXPORT ur_result_t UR_APICALL
urProgramGetGlobalVariablePointer(
ur_device_handle_t hDevice, ///< [in] handle of the device to retrieve the pointer for.
ur_program_handle_t hProgram, ///< [in] handle of the program where the global variable is.
const char *pGlobalVariableName, ///< [in] mangled name of the global variable to retrieve the pointer for.
size_t *pGlobalVariableSizeRet, ///< [out][optional] Returns the size of the global variable if it is found
///< in the program.
void **ppGlobalVariablePointerRet ///< [out] Returns the pointer to the global variable if it is found in the program.
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Get Program object information
typedef enum ur_program_info_t {
Expand Down Expand Up @@ -9144,6 +9181,18 @@ typedef struct ur_program_get_function_pointer_params_t {
void ***pppFunctionPointer;
} ur_program_get_function_pointer_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramGetGlobalVariablePointer
/// @details Each entry is a pointer to the parameter passed to the function;
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_program_get_global_variable_pointer_params_t {
ur_device_handle_t *phDevice;
ur_program_handle_t *phProgram;
const char **ppGlobalVariableName;
size_t **ppGlobalVariableSizeRet;
void ***pppGlobalVariablePointerRet;
} ur_program_get_global_variable_pointer_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramGetInfo
/// @details Each entry is a pointer to the parameter passed to the function;
Expand Down
10 changes: 10 additions & 0 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,15 @@ typedef ur_result_t(UR_APICALL *ur_pfnProgramGetFunctionPointer_t)(
const char *,
void **);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urProgramGetGlobalVariablePointer
typedef ur_result_t(UR_APICALL *ur_pfnProgramGetGlobalVariablePointer_t)(
ur_device_handle_t,
ur_program_handle_t,
const char *,
size_t *,
void **);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urProgramGetInfo
typedef ur_result_t(UR_APICALL *ur_pfnProgramGetInfo_t)(
Expand Down Expand Up @@ -380,6 +389,7 @@ typedef struct ur_program_dditable_t {
ur_pfnProgramRetain_t pfnRetain;
ur_pfnProgramRelease_t pfnRelease;
ur_pfnProgramGetFunctionPointer_t pfnGetFunctionPointer;
ur_pfnProgramGetGlobalVariablePointer_t pfnGetGlobalVariablePointer;
ur_pfnProgramGetInfo_t pfnGetInfo;
ur_pfnProgramGetBuildInfo_t pfnGetBuildInfo;
ur_pfnProgramSetSpecializationConstants_t pfnSetSpecializationConstants;
Expand Down
44 changes: 44 additions & 0 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,9 @@ inline std::ostream &operator<<(std::ostream &os, ur_function_t value) {
case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP:
os << "UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP";
break;
case UR_FUNCTION_PROGRAM_GET_GLOBAL_VARIABLE_POINTER:
os << "UR_FUNCTION_PROGRAM_GET_GLOBAL_VARIABLE_POINTER";
break;
default:
os << "unknown enumerator";
break;
Expand Down Expand Up @@ -10257,6 +10260,44 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
return os;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ur_program_get_global_variable_pointer_params_t type
/// @returns
/// std::ostream &
inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_program_get_global_variable_pointer_params_t *params) {

os << ".hDevice = ";

ur::details::printPtr(os,
*(params->phDevice));

os << ", ";
os << ".hProgram = ";

ur::details::printPtr(os,
*(params->phProgram));

os << ", ";
os << ".pGlobalVariableName = ";

ur::details::printPtr(os,
*(params->ppGlobalVariableName));

os << ", ";
os << ".pGlobalVariableSizeRet = ";

ur::details::printPtr(os,
*(params->ppGlobalVariableSizeRet));

os << ", ";
os << ".ppGlobalVariablePointerRet = ";

ur::details::printPtr(os,
*(params->pppGlobalVariablePointerRet));

return os;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ur_program_get_info_params_t type
/// @returns
Expand Down Expand Up @@ -15992,6 +16033,9 @@ inline ur_result_t UR_APICALL printFunctionParams(std::ostream &os, ur_function_
case UR_FUNCTION_PROGRAM_GET_FUNCTION_POINTER: {
os << (const struct ur_program_get_function_pointer_params_t *)params;
} break;
case UR_FUNCTION_PROGRAM_GET_GLOBAL_VARIABLE_POINTER: {
os << (const struct ur_program_get_global_variable_pointer_params_t *)params;
} break;
case UR_FUNCTION_PROGRAM_GET_INFO: {
os << (const struct ur_program_get_info_params_t *)params;
} break;
Expand Down
37 changes: 37 additions & 0 deletions scripts/core/program.yml
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,43 @@ params:
desc: |
[out] Returns the pointer to the function if it is found in the program.
--- #--------------------------------------------------------------------------
type: function
desc: "Retrieves a pointer to a device global variable."
class: $xProgram
name: GetGlobalVariablePointer
decl: static
ordinal: "7"
analogue:
- "**clGetDeviceGlobalVariablePointerINTEL**"
details:
- "Retrieves a pointer to a device global variable."
- "The application may call this function from simultaneous threads for the same device."
- "The implementation of this function should be thread-safe."
params:
- type: "$x_device_handle_t"
name: hDevice
desc: |
[in] handle of the device to retrieve the pointer for.
- type: "$x_program_handle_t"
name: hProgram
desc: |
[in] handle of the program where the global variable is.
- type: "const char*"
name: pGlobalVariableName
desc: |
[in] mangled name of the global variable to retrieve the pointer for.
- type: "size_t*"
name: pGlobalVariableSizeRet
desc: |
[out][optional] Returns the size of the global variable if it is found in the program.
- type: "void**"
name: ppGlobalVariablePointerRet
desc: |
[out] Returns the pointer to the global variable if it is found in the program.
returns:
- $X_RESULT_ERROR_INVALID_SIZE:
- "`name` is not a valid variable in the program."
--- #--------------------------------------------------------------------------
type: enum
desc: "Get Program object information"
class: $xProgram
Expand Down
3 changes: 3 additions & 0 deletions scripts/core/registry.yml
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,9 @@ etors:
- name: KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP
desc: Enumerator for $xKernelSuggestMaxCooperativeGroupCountExp
value: '215'
- name: PROGRAM_GET_GLOBAL_VARIABLE_POINTER
desc: Enumerator for $xProgramGetGlobalVariablePointer
value: '216'
---
type: enum
desc: Defines structure types
Expand Down
31 changes: 31 additions & 0 deletions source/adapters/cuda/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(

return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
ur_device_handle_t, ur_program_handle_t hProgram,
const char *pGlobalVariableName, size_t *pGlobalVariableSizeRet,
void **ppGlobalVariablePointerRet) {

/* Since CUDA requires a global variable to be referenced by name, we use
* metadata to find the correct name to access it by. */
auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(pGlobalVariableName);
if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
return UR_RESULT_ERROR_INVALID_VALUE;
std::string DeviceGlobalName = DeviceGlobalNameIt->second;

ur_result_t Result = UR_RESULT_SUCCESS;
try {
CUdeviceptr DeviceGlobal = 0;
size_t DeviceGlobalSize = 0;
UR_CHECK_ERROR(cuModuleGetGlobal(&DeviceGlobal, &DeviceGlobalSize,
hProgram->get(),
DeviceGlobalName.c_str()));

if (pGlobalVariableSizeRet) {
*pGlobalVariableSizeRet = DeviceGlobalSize;
}
*ppGlobalVariablePointerRet = reinterpret_cast<void *>(DeviceGlobal);

} catch (ur_result_t Err) {
Result = Err;
}
return Result;
}
1 change: 1 addition & 0 deletions source/adapters/cuda/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable(
pDdiTable->pfnCreateWithNativeHandle = urProgramCreateWithNativeHandle;
pDdiTable->pfnGetBuildInfo = urProgramGetBuildInfo;
pDdiTable->pfnGetFunctionPointer = urProgramGetFunctionPointer;
pDdiTable->pfnGetGlobalVariablePointer = urProgramGetGlobalVariablePointer;
pDdiTable->pfnGetInfo = urProgramGetInfo;
pDdiTable->pfnGetNativeHandle = urProgramGetNativeHandle;
pDdiTable->pfnLink = urProgramLink;
Expand Down
5 changes: 5 additions & 0 deletions source/adapters/hip/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(

return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
ur_device_handle_t, ur_program_handle_t, const char *, size_t *, void **) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}
1 change: 1 addition & 0 deletions source/adapters/hip/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable(
pDdiTable->pfnCreateWithNativeHandle = urProgramCreateWithNativeHandle;
pDdiTable->pfnGetBuildInfo = urProgramGetBuildInfo;
pDdiTable->pfnGetFunctionPointer = urProgramGetFunctionPointer;
pDdiTable->pfnGetGlobalVariablePointer = urProgramGetGlobalVariablePointer;
pDdiTable->pfnGetInfo = urProgramGetInfo;
pDdiTable->pfnGetNativeHandle = urProgramGetNativeHandle;
pDdiTable->pfnLink = urProgramLink;
Expand Down
5 changes: 5 additions & 0 deletions source/adapters/level_zero/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(
return ze2urResult(ZeResult);
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
ur_device_handle_t, ur_program_handle_t, const char *, size_t *, void **) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
ur_program_handle_t Program, ///< [in] handle of the Program object
ur_program_info_t PropName, ///< [in] name of the Program property to query
Expand Down
1 change: 1 addition & 0 deletions source/adapters/level_zero/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable(
pDdiTable->pfnRetain = urProgramRetain;
pDdiTable->pfnRelease = urProgramRelease;
pDdiTable->pfnGetFunctionPointer = urProgramGetFunctionPointer;
pDdiTable->pfnGetGlobalVariablePointer = urProgramGetGlobalVariablePointer;
pDdiTable->pfnGetInfo = urProgramGetInfo;
pDdiTable->pfnGetBuildInfo = urProgramGetBuildInfo;
pDdiTable->pfnSetSpecializationConstants =
Expand Down
36 changes: 36 additions & 0 deletions source/adapters/null/ur_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,39 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetFunctionPointer(
return exceptionToResult(std::current_exception());
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urProgramGetGlobalVariablePointer
__urdlllocal ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
ur_device_handle_t
hDevice, ///< [in] handle of the device to retrieve the pointer for.
ur_program_handle_t
hProgram, ///< [in] handle of the program where the global variable is.
const char *
pGlobalVariableName, ///< [in] mangled name of the global variable to retrieve the pointer for.
size_t *
pGlobalVariableSizeRet, ///< [out][optional] Returns the size of the global variable if it is found
///< in the program.
void **
ppGlobalVariablePointerRet ///< [out] Returns the pointer to the global variable if it is found in the program.
) try {
ur_result_t result = UR_RESULT_SUCCESS;

// if the driver has created a custom function, then call it instead of using the generic path
auto pfnGetGlobalVariablePointer =
d_context.urDdiTable.Program.pfnGetGlobalVariablePointer;
if (nullptr != pfnGetGlobalVariablePointer) {
result = pfnGetGlobalVariablePointer(
hDevice, hProgram, pGlobalVariableName, pGlobalVariableSizeRet,
ppGlobalVariablePointerRet);
} else {
// generic implementation
}

return result;
} catch (...) {
return exceptionToResult(std::current_exception());
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urProgramGetInfo
__urdlllocal ur_result_t UR_APICALL urProgramGetInfo(
Expand Down Expand Up @@ -5941,6 +5974,9 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable(

pDdiTable->pfnGetFunctionPointer = driver::urProgramGetFunctionPointer;

pDdiTable->pfnGetGlobalVariablePointer =
driver::urProgramGetGlobalVariablePointer;

pDdiTable->pfnGetInfo = driver::urProgramGetInfo;

pDdiTable->pfnGetBuildInfo = driver::urProgramGetBuildInfo;
Expand Down
10 changes: 9 additions & 1 deletion source/adapters/opencl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
* error is mapped to UR
*/
#define CL_RETURN_ON_FAILURE_AND_SET_NULL(clCall, outPtr) \
if (const cl_int cl_result_macro = clCall != CL_SUCCESS) { \
if (const cl_int cl_result_macro = clCall; cl_result_macro != CL_SUCCESS) { \
if (outPtr != nullptr) { \
*outPtr = nullptr; \
} \
Expand Down Expand Up @@ -197,6 +197,8 @@ CONSTFIX char SetProgramSpecializationConstantName[] =
"clSetProgramSpecializationConstant";
CONSTFIX char GetDeviceFunctionPointerName[] =
"clGetDeviceFunctionPointerINTEL";
CONSTFIX char GetDeviceGlobalVariablePointerName[] =
"clGetDeviceGlobalVariablePointerINTEL";
CONSTFIX char EnqueueWriteGlobalVariableName[] =
"clEnqueueWriteGlobalVariableINTEL";
CONSTFIX char EnqueueReadGlobalVariableName[] =
Expand All @@ -221,6 +223,10 @@ using clGetDeviceFunctionPointer_fn = CL_API_ENTRY
cl_int(CL_API_CALL *)(cl_device_id device, cl_program program,
const char *FuncName, cl_ulong *ret_ptr);

using clGetDeviceGlobalVariablePointer_fn = CL_API_ENTRY cl_int(CL_API_CALL *)(
cl_device_id device, cl_program program, const char *globalVariableName,
size_t *globalVariableSizeRet, void **globalVariablePointerRet);

using clEnqueueWriteGlobalVariable_fn = CL_API_ENTRY
cl_int(CL_API_CALL *)(cl_command_queue, cl_program, const char *, cl_bool,
size_t, size_t, const void *, cl_uint, const cl_event *,
Expand Down Expand Up @@ -314,6 +320,8 @@ struct ExtFuncPtrCacheT {
FuncPtrCache<clDeviceMemAllocINTEL_fn> clDeviceMemAllocINTELCache;
FuncPtrCache<clSharedMemAllocINTEL_fn> clSharedMemAllocINTELCache;
FuncPtrCache<clGetDeviceFunctionPointer_fn> clGetDeviceFunctionPointerCache;
FuncPtrCache<clGetDeviceGlobalVariablePointer_fn>
clGetDeviceGlobalVariablePointerCache;
FuncPtrCache<clCreateBufferWithPropertiesINTEL_fn>
clCreateBufferWithPropertiesINTELCache;
FuncPtrCache<clMemBlockingFreeINTEL_fn> clMemBlockingFreeINTELCache;
Expand Down
Loading

0 comments on commit f5e3493

Please sign in to comment.