Skip to content

Commit

Permalink
Merge pull request #1255 from fabiomestre/fabio/add_global_variable_p…
Browse files Browse the repository at this point in the history
…ointer

[SPEC] Add urProgramGetGlobalVariablePointer entrypoint
  • Loading branch information
kbenzie authored Mar 18, 2024
2 parents 29ee45c + ca3da5a commit 4d0183a
Show file tree
Hide file tree
Showing 39 changed files with 723 additions and 80 deletions.
61 changes: 55 additions & 6 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ typedef enum ur_function_t {
UR_FUNCTION_ADAPTER_RETAIN = 179, ///< Enumerator for ::urAdapterRetain
UR_FUNCTION_ADAPTER_GET_LAST_ERROR = 180, ///< Enumerator for ::urAdapterGetLastError
UR_FUNCTION_ADAPTER_GET_INFO = 181, ///< Enumerator for ::urAdapterGetInfo
UR_FUNCTION_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP = 182, ///< Enumerator for ::urCommandBufferUpdateKernelLaunchExp
UR_FUNCTION_PROGRAM_BUILD_EXP = 197, ///< Enumerator for ::urProgramBuildExp
UR_FUNCTION_PROGRAM_COMPILE_EXP = 198, ///< Enumerator for ::urProgramCompileExp
UR_FUNCTION_PROGRAM_LINK_EXP = 199, ///< Enumerator for ::urProgramLinkExp
Expand All @@ -216,11 +215,13 @@ typedef enum ur_function_t {
UR_FUNCTION_COMMAND_BUFFER_APPEND_USM_ADVISE_EXP = 213, ///< Enumerator for ::urCommandBufferAppendUSMAdviseExp
UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 214, ///< Enumerator for ::urEnqueueCooperativeKernelLaunchExp
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 215, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
UR_FUNCTION_COMMAND_BUFFER_RETAIN_COMMAND_EXP = 216, ///< Enumerator for ::urCommandBufferRetainCommandExp
UR_FUNCTION_COMMAND_BUFFER_RELEASE_COMMAND_EXP = 217, ///< Enumerator for ::urCommandBufferReleaseCommandExp
UR_FUNCTION_COMMAND_BUFFER_GET_INFO_EXP = 218, ///< Enumerator for ::urCommandBufferGetInfoExp
UR_FUNCTION_COMMAND_BUFFER_COMMAND_GET_INFO_EXP = 219, ///< Enumerator for ::urCommandBufferCommandGetInfoExp
UR_FUNCTION_DEVICE_GET_SELECTED = 220, ///< Enumerator for ::urDeviceGetSelected
UR_FUNCTION_PROGRAM_GET_GLOBAL_VARIABLE_POINTER = 216, ///< Enumerator for ::urProgramGetGlobalVariablePointer
UR_FUNCTION_DEVICE_GET_SELECTED = 217, ///< Enumerator for ::urDeviceGetSelected
UR_FUNCTION_COMMAND_BUFFER_RETAIN_COMMAND_EXP = 218, ///< Enumerator for ::urCommandBufferRetainCommandExp
UR_FUNCTION_COMMAND_BUFFER_RELEASE_COMMAND_EXP = 219, ///< Enumerator for ::urCommandBufferReleaseCommandExp
UR_FUNCTION_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP = 220, ///< Enumerator for ::urCommandBufferUpdateKernelLaunchExp
UR_FUNCTION_COMMAND_BUFFER_GET_INFO_EXP = 221, ///< Enumerator for ::urCommandBufferGetInfoExp
UR_FUNCTION_COMMAND_BUFFER_COMMAND_GET_INFO_EXP = 222, ///< Enumerator for ::urCommandBufferCommandGetInfoExp
/// @cond
UR_FUNCTION_FORCE_UINT32 = 0x7fffffff
/// @endcond
Expand Down Expand Up @@ -4327,6 +4328,42 @@ urProgramGetFunctionPointer(
void **ppFunctionPointer ///< [out] Returns the pointer to the function if it is found in the program.
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Retrieves a pointer to a device global variable.
///
/// @details
/// - Retrieves a pointer to a device global variable.
/// - The application may call this function from simultaneous threads for
/// the same device.
/// - The implementation of this function should be thread-safe.
///
/// @remarks
/// _Analogues_
/// - **clGetDeviceGlobalVariablePointerINTEL**
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hDevice`
/// + `NULL == hProgram`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == pGlobalVariableName`
/// + `NULL == ppGlobalVariablePointerRet`
/// - ::UR_RESULT_ERROR_INVALID_VALUE
/// + `name` is not a valid variable in the program.
UR_APIEXPORT ur_result_t UR_APICALL
urProgramGetGlobalVariablePointer(
ur_device_handle_t hDevice, ///< [in] handle of the device to retrieve the pointer for.
ur_program_handle_t hProgram, ///< [in] handle of the program where the global variable is.
const char *pGlobalVariableName, ///< [in] mangled name of the global variable to retrieve the pointer for.
size_t *pGlobalVariableSizeRet, ///< [out][optional] Returns the size of the global variable if it is found
///< in the program.
void **ppGlobalVariablePointerRet ///< [out] Returns the pointer to the global variable if it is found in the program.
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Get Program object information
typedef enum ur_program_info_t {
Expand Down Expand Up @@ -9470,6 +9507,18 @@ typedef struct ur_program_get_function_pointer_params_t {
void ***pppFunctionPointer;
} ur_program_get_function_pointer_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramGetGlobalVariablePointer
/// @details Each entry is a pointer to the parameter passed to the function;
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_program_get_global_variable_pointer_params_t {
ur_device_handle_t *phDevice;
ur_program_handle_t *phProgram;
const char **ppGlobalVariableName;
size_t **ppGlobalVariableSizeRet;
void ***pppGlobalVariablePointerRet;
} ur_program_get_global_variable_pointer_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramGetInfo
/// @details Each entry is a pointer to the parameter passed to the function;
Expand Down
10 changes: 10 additions & 0 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,15 @@ typedef ur_result_t(UR_APICALL *ur_pfnProgramGetFunctionPointer_t)(
const char *,
void **);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urProgramGetGlobalVariablePointer
typedef ur_result_t(UR_APICALL *ur_pfnProgramGetGlobalVariablePointer_t)(
ur_device_handle_t,
ur_program_handle_t,
const char *,
size_t *,
void **);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urProgramGetInfo
typedef ur_result_t(UR_APICALL *ur_pfnProgramGetInfo_t)(
Expand Down Expand Up @@ -380,6 +389,7 @@ typedef struct ur_program_dditable_t {
ur_pfnProgramRetain_t pfnRetain;
ur_pfnProgramRelease_t pfnRelease;
ur_pfnProgramGetFunctionPointer_t pfnGetFunctionPointer;
ur_pfnProgramGetGlobalVariablePointer_t pfnGetGlobalVariablePointer;
ur_pfnProgramGetInfo_t pfnGetInfo;
ur_pfnProgramGetBuildInfo_t pfnGetBuildInfo;
ur_pfnProgramSetSpecializationConstants_t pfnSetSpecializationConstants;
Expand Down
8 changes: 8 additions & 0 deletions include/ur_print.h
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urPrintProgramReleaseParams(const struct ur_
/// - `buff_size < out_size`
UR_APIEXPORT ur_result_t UR_APICALL urPrintProgramGetFunctionPointerParams(const struct ur_program_get_function_pointer_params_t *params, char *buffer, const size_t buff_size, size_t *out_size);

///////////////////////////////////////////////////////////////////////////////
/// @brief Print ur_program_get_global_variable_pointer_params_t struct
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_INVALID_SIZE
/// - `buff_size < out_size`
UR_APIEXPORT ur_result_t UR_APICALL urPrintProgramGetGlobalVariablePointerParams(const struct ur_program_get_global_variable_pointer_params_t *params, char *buffer, const size_t buff_size, size_t *out_size);

///////////////////////////////////////////////////////////////////////////////
/// @brief Print ur_program_get_info_params_t struct
/// @returns
Expand Down
56 changes: 50 additions & 6 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,9 +837,6 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) {
case UR_FUNCTION_ADAPTER_GET_INFO:
os << "UR_FUNCTION_ADAPTER_GET_INFO";
break;
case UR_FUNCTION_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP:
os << "UR_FUNCTION_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP";
break;
case UR_FUNCTION_PROGRAM_BUILD_EXP:
os << "UR_FUNCTION_PROGRAM_BUILD_EXP";
break;
Expand Down Expand Up @@ -897,21 +894,27 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) {
case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP:
os << "UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP";
break;
case UR_FUNCTION_PROGRAM_GET_GLOBAL_VARIABLE_POINTER:
os << "UR_FUNCTION_PROGRAM_GET_GLOBAL_VARIABLE_POINTER";
break;
case UR_FUNCTION_DEVICE_GET_SELECTED:
os << "UR_FUNCTION_DEVICE_GET_SELECTED";
break;
case UR_FUNCTION_COMMAND_BUFFER_RETAIN_COMMAND_EXP:
os << "UR_FUNCTION_COMMAND_BUFFER_RETAIN_COMMAND_EXP";
break;
case UR_FUNCTION_COMMAND_BUFFER_RELEASE_COMMAND_EXP:
os << "UR_FUNCTION_COMMAND_BUFFER_RELEASE_COMMAND_EXP";
break;
case UR_FUNCTION_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP:
os << "UR_FUNCTION_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP";
break;
case UR_FUNCTION_COMMAND_BUFFER_GET_INFO_EXP:
os << "UR_FUNCTION_COMMAND_BUFFER_GET_INFO_EXP";
break;
case UR_FUNCTION_COMMAND_BUFFER_COMMAND_GET_INFO_EXP:
os << "UR_FUNCTION_COMMAND_BUFFER_COMMAND_GET_INFO_EXP";
break;
case UR_FUNCTION_DEVICE_GET_SELECTED:
os << "UR_FUNCTION_DEVICE_GET_SELECTED";
break;
default:
os << "unknown enumerator";
break;
Expand Down Expand Up @@ -10791,6 +10794,44 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
return os;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ur_program_get_global_variable_pointer_params_t type
/// @returns
/// std::ostream &
inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_program_get_global_variable_pointer_params_t *params) {

os << ".hDevice = ";

ur::details::printPtr(os,
*(params->phDevice));

os << ", ";
os << ".hProgram = ";

ur::details::printPtr(os,
*(params->phProgram));

os << ", ";
os << ".pGlobalVariableName = ";

ur::details::printPtr(os,
*(params->ppGlobalVariableName));

os << ", ";
os << ".pGlobalVariableSizeRet = ";

ur::details::printPtr(os,
*(params->ppGlobalVariableSizeRet));

os << ", ";
os << ".ppGlobalVariablePointerRet = ";

ur::details::printPtr(os,
*(params->pppGlobalVariablePointerRet));

return os;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ur_program_get_info_params_t type
/// @returns
Expand Down Expand Up @@ -16706,6 +16747,9 @@ inline ur_result_t UR_APICALL printFunctionParams(std::ostream &os, ur_function_
case UR_FUNCTION_PROGRAM_GET_FUNCTION_POINTER: {
os << (const struct ur_program_get_function_pointer_params_t *)params;
} break;
case UR_FUNCTION_PROGRAM_GET_GLOBAL_VARIABLE_POINTER: {
os << (const struct ur_program_get_global_variable_pointer_params_t *)params;
} break;
case UR_FUNCTION_PROGRAM_GET_INFO: {
os << (const struct ur_program_get_info_params_t *)params;
} break;
Expand Down
37 changes: 37 additions & 0 deletions scripts/core/program.yml
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,43 @@ params:
desc: |
[out] Returns the pointer to the function if it is found in the program.
--- #--------------------------------------------------------------------------
type: function
desc: "Retrieves a pointer to a device global variable."
class: $xProgram
name: GetGlobalVariablePointer
decl: static
ordinal: "7"
analogue:
- "**clGetDeviceGlobalVariablePointerINTEL**"
details:
- "Retrieves a pointer to a device global variable."
- "The application may call this function from simultaneous threads for the same device."
- "The implementation of this function should be thread-safe."
params:
- type: "$x_device_handle_t"
name: hDevice
desc: |
[in] handle of the device to retrieve the pointer for.
- type: "$x_program_handle_t"
name: hProgram
desc: |
[in] handle of the program where the global variable is.
- type: "const char*"
name: pGlobalVariableName
desc: |
[in] mangled name of the global variable to retrieve the pointer for.
- type: "size_t*"
name: pGlobalVariableSizeRet
desc: |
[out][optional] Returns the size of the global variable if it is found in the program.
- type: "void**"
name: ppGlobalVariablePointerRet
desc: |
[out] Returns the pointer to the global variable if it is found in the program.
returns:
- $X_RESULT_ERROR_INVALID_VALUE:
- "`name` is not a valid variable in the program."
--- #--------------------------------------------------------------------------
type: enum
desc: "Get Program object information"
class: $xProgram
Expand Down
23 changes: 13 additions & 10 deletions scripts/core/registry.yml
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,6 @@ etors:
- name: ADAPTER_GET_INFO
desc: Enumerator for $xAdapterGetInfo
value: '181'
- name: COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP
desc: Enumerator for $xCommandBufferUpdateKernelLaunchExp
value: '182'
- name: PROGRAM_BUILD_EXP
desc: Enumerator for $xProgramBuildExp
value: '197'
Expand Down Expand Up @@ -562,21 +559,27 @@ etors:
- name: KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP
desc: Enumerator for $xKernelSuggestMaxCooperativeGroupCountExp
value: '215'
- name: PROGRAM_GET_GLOBAL_VARIABLE_POINTER
desc: Enumerator for $xProgramGetGlobalVariablePointer
value: '216'
- name: DEVICE_GET_SELECTED
desc: Enumerator for $xDeviceGetSelected
value: '217'
- name: COMMAND_BUFFER_RETAIN_COMMAND_EXP
desc: Enumerator for $xCommandBufferRetainCommandExp
value: '216'
value: '218'
- name: COMMAND_BUFFER_RELEASE_COMMAND_EXP
desc: Enumerator for $xCommandBufferReleaseCommandExp
value: '217'
value: '219'
- name: COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP
desc: Enumerator for $xCommandBufferUpdateKernelLaunchExp
value: '220'
- name: COMMAND_BUFFER_GET_INFO_EXP
desc: Enumerator for $xCommandBufferGetInfoExp
value: '218'
value: '221'
- name: COMMAND_BUFFER_COMMAND_GET_INFO_EXP
desc: Enumerator for $xCommandBufferCommandGetInfoExp
value: '219'
- name: DEVICE_GET_SELECTED
desc: Enumerator for $xDeviceGetSelected
value: '220'
value: '222'
---
type: enum
desc: Defines structure types
Expand Down
32 changes: 6 additions & 26 deletions source/adapters/cuda/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1660,20 +1660,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
bool blockingWrite, size_t count, size_t offset, const void *pSrc,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
// Since CUDA requires a the global variable to be referenced by name, we use
// metadata to find the correct name to access it by.
auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(name);
if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
return UR_RESULT_ERROR_INVALID_VALUE;
std::string DeviceGlobalName = DeviceGlobalNameIt->second;

ur_result_t Result = UR_RESULT_SUCCESS;
try {
CUdeviceptr DeviceGlobal = 0;
size_t DeviceGlobalSize = 0;
UR_CHECK_ERROR(cuModuleGetGlobal(&DeviceGlobal, &DeviceGlobalSize,
hProgram->get(),
DeviceGlobalName.c_str()));
UR_CHECK_ERROR(hProgram->getGlobalVariablePointer(name, &DeviceGlobal,
&DeviceGlobalSize));

if (offset + count > DeviceGlobalSize)
return UR_RESULT_ERROR_INVALID_VALUE;
Expand All @@ -1682,30 +1673,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
hQueue, blockingWrite, reinterpret_cast<void *>(DeviceGlobal + offset),
pSrc, count, numEventsInWaitList, phEventWaitList, phEvent);
} catch (ur_result_t Err) {
Result = Err;
return Err;
}
return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
bool blockingRead, size_t count, size_t offset, void *pDst,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
// Since CUDA requires a the global variable to be referenced by name, we use
// metadata to find the correct name to access it by.
auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(name);
if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
return UR_RESULT_ERROR_INVALID_VALUE;
std::string DeviceGlobalName = DeviceGlobalNameIt->second;

ur_result_t Result = UR_RESULT_SUCCESS;
try {
CUdeviceptr DeviceGlobal = 0;
size_t DeviceGlobalSize = 0;
UR_CHECK_ERROR(cuModuleGetGlobal(&DeviceGlobal, &DeviceGlobalSize,
hProgram->get(),
DeviceGlobalName.c_str()));
UR_CHECK_ERROR(hProgram->getGlobalVariablePointer(name, &DeviceGlobal,
&DeviceGlobalSize));

if (offset + count > DeviceGlobalSize)
return UR_RESULT_ERROR_INVALID_VALUE;
Expand All @@ -1715,9 +1696,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
reinterpret_cast<const void *>(DeviceGlobal + offset), count,
numEventsInWaitList, phEventWaitList, phEvent);
} catch (ur_result_t Err) {
Result = Err;
return Err;
}
return Result;
}

/// Host Pipes
Expand Down
Loading

0 comments on commit 4d0183a

Please sign in to comment.