Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ur] Introduce virtual memory interfaces #525

Merged
merged 1 commit into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ class ur_sampler_handle_t(c_void_p):
class ur_mem_handle_t(c_void_p):
pass

###############################################################################
## @brief Handle of physical memory object
class ur_physical_mem_handle_t(c_void_p):
pass

###############################################################################
## @brief Generic macro for enumerator bit masks
def UR_BIT( _i ):
Expand Down Expand Up @@ -234,6 +239,7 @@ class ur_structure_type_v(IntEnum):
EXP_COMMAND_BUFFER_DESC = 27 ## ::ur_exp_command_buffer_desc_t
EXP_SAMPLER_MIP_PROPERTIES = 28 ## ::ur_exp_sampler_mip_properties_t
KERNEL_ARG_MEM_OBJ_PROPERTIES = 29 ## ::ur_kernel_arg_mem_obj_properties_t
PHYSICAL_MEM_PROPERTIES = 30 ## ::ur_physical_mem_properties_t

class ur_structure_type_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -1304,6 +1310,59 @@ def __str__(self):
return str(ur_usm_pool_info_v(self.value))


###############################################################################
## @brief Virtual memory granularity info
class ur_virtual_mem_granularity_info_v(IntEnum):
MINIMUM = 0x30100 ## [size_t] size in bytes of the minimum virtual memory granularity.
RECOMMENDED = 0x30101 ## [size_t] size in bytes of the recommended virtual memory granularity.

class ur_virtual_mem_granularity_info_t(c_int):
def __str__(self):
return str(ur_virtual_mem_granularity_info_v(self.value))


###############################################################################
## @brief Virtual memory access mode flags.
class ur_virtual_mem_access_flags_v(IntEnum):
READ_WRITE = UR_BIT(0) ## Virtual memory both read and write accessible
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late reaction to this, but I just realized that we are losing the ability to set the access mode to "none" here by defining READ_WRITE as 0. CUDA considers an access mode of 0 as revoking access. @kbenzie - Could we either have a NONE value here or offset the existing ones?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed that! Yeah, I'll add none.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#638 will fix this.

READ_ONLY = UR_BIT(1) ##

class ur_virtual_mem_access_flags_t(c_int):
def __str__(self):
return hex(self.value)


###############################################################################
## @brief Virtual memory range info queries.
class ur_virtual_mem_info_v(IntEnum):
ACCESS_MODE = 0 ## [::ur_virtual_mem_access_flags_t] access flags of a mapped virtual
## memory range.

class ur_virtual_mem_info_t(c_int):
def __str__(self):
return str(ur_virtual_mem_info_v(self.value))


###############################################################################
## @brief Physical memory creation properties.
class ur_physical_mem_flags_v(IntEnum):
TBD = UR_BIT(0) ## reserved for future use.

class ur_physical_mem_flags_t(c_int):
def __str__(self):
return hex(self.value)


###############################################################################
## @brief Physical memory creation properties.
class ur_physical_mem_properties_t(Structure):
_fields_ = [
("stype", ur_structure_type_t), ## [in] type of this structure, must be
## ::UR_STRUCTURE_TYPE_PHYSICAL_MEM_PROPERTIES
("pNext", c_void_p), ## [in,out][optional] pointer to extension-specific structure
("flags", ur_physical_mem_flags_t) ## [in] physical memory creation flags
]

###############################################################################
## @brief Program metadata property type.
class ur_program_metadata_type_v(IntEnum):
Expand Down Expand Up @@ -1902,6 +1961,16 @@ class ur_function_v(IntEnum):
PLATFORM_GET_LAST_ERROR = 150 ## Enumerator for ::urPlatformGetLastError
ENQUEUE_USM_FILL_2D = 151 ## Enumerator for ::urEnqueueUSMFill2D
ENQUEUE_USM_MEMCPY_2D = 152 ## Enumerator for ::urEnqueueUSMMemcpy2D
VIRTUAL_MEM_GRANULARITY_GET_INFO = 153 ## Enumerator for ::urVirtualMemGranularityGetInfo
VIRTUAL_MEM_RESERVE = 154 ## Enumerator for ::urVirtualMemReserve
VIRTUAL_MEM_FREE = 155 ## Enumerator for ::urVirtualMemFree
VIRTUAL_MEM_MAP = 156 ## Enumerator for ::urVirtualMemMap
VIRTUAL_MEM_UNMAP = 157 ## Enumerator for ::urVirtualMemUnmap
VIRTUAL_MEM_SET_ACCESS = 158 ## Enumerator for ::urVirtualMemSetAccess
VIRTUAL_MEM_GET_INFO = 159 ## Enumerator for ::urVirtualMemGetInfo
PHYSICAL_MEM_CREATE = 160 ## Enumerator for ::urPhysicalMemCreate
PHYSICAL_MEM_RETAIN = 161 ## Enumerator for ::urPhysicalMemRetain
PHYSICAL_MEM_RELEASE = 162 ## Enumerator for ::urPhysicalMemRelease

class ur_function_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -2586,6 +2655,37 @@ class ur_mem_dditable_t(Structure):
("pfnImageGetInfo", c_void_p) ## _urMemImageGetInfo_t
]

###############################################################################
## @brief Function-pointer for urPhysicalMemCreate
if __use_win_types:
_urPhysicalMemCreate_t = WINFUNCTYPE( ur_result_t, ur_context_handle_t, ur_device_handle_t, c_size_t, POINTER(ur_physical_mem_properties_t), POINTER(ur_physical_mem_handle_t) )
else:
_urPhysicalMemCreate_t = CFUNCTYPE( ur_result_t, ur_context_handle_t, ur_device_handle_t, c_size_t, POINTER(ur_physical_mem_properties_t), POINTER(ur_physical_mem_handle_t) )

###############################################################################
## @brief Function-pointer for urPhysicalMemRetain
if __use_win_types:
_urPhysicalMemRetain_t = WINFUNCTYPE( ur_result_t, ur_physical_mem_handle_t )
else:
_urPhysicalMemRetain_t = CFUNCTYPE( ur_result_t, ur_physical_mem_handle_t )

###############################################################################
## @brief Function-pointer for urPhysicalMemRelease
if __use_win_types:
_urPhysicalMemRelease_t = WINFUNCTYPE( ur_result_t, ur_physical_mem_handle_t )
else:
_urPhysicalMemRelease_t = CFUNCTYPE( ur_result_t, ur_physical_mem_handle_t )


###############################################################################
## @brief Table of PhysicalMem functions pointers
class ur_physical_mem_dditable_t(Structure):
_fields_ = [
("pfnCreate", c_void_p), ## _urPhysicalMemCreate_t
("pfnRetain", c_void_p), ## _urPhysicalMemRetain_t
("pfnRelease", c_void_p) ## _urPhysicalMemRelease_t
]

###############################################################################
## @brief Function-pointer for urEnqueueKernelLaunch
if __use_win_types:
Expand Down Expand Up @@ -3203,6 +3303,69 @@ class ur_global_dditable_t(Structure):
("pfnTearDown", c_void_p) ## _urTearDown_t
]

###############################################################################
## @brief Function-pointer for urVirtualMemGranularityGetInfo
if __use_win_types:
_urVirtualMemGranularityGetInfo_t = WINFUNCTYPE( ur_result_t, ur_context_handle_t, ur_device_handle_t, ur_virtual_mem_granularity_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
else:
_urVirtualMemGranularityGetInfo_t = CFUNCTYPE( ur_result_t, ur_context_handle_t, ur_device_handle_t, ur_virtual_mem_granularity_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
kbenzie marked this conversation as resolved.
Show resolved Hide resolved

###############################################################################
## @brief Function-pointer for urVirtualMemReserve
if __use_win_types:
_urVirtualMemReserve_t = WINFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t, POINTER(c_void_p) )
else:
_urVirtualMemReserve_t = CFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t, POINTER(c_void_p) )

###############################################################################
## @brief Function-pointer for urVirtualMemFree
if __use_win_types:
_urVirtualMemFree_t = WINFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t )
else:
_urVirtualMemFree_t = CFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t )

###############################################################################
## @brief Function-pointer for urVirtualMemMap
if __use_win_types:
_urVirtualMemMap_t = WINFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t, ur_physical_mem_handle_t, c_size_t, ur_virtual_mem_access_flags_t )
else:
_urVirtualMemMap_t = CFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t, ur_physical_mem_handle_t, c_size_t, ur_virtual_mem_access_flags_t )

###############################################################################
## @brief Function-pointer for urVirtualMemUnmap
if __use_win_types:
_urVirtualMemUnmap_t = WINFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t )
else:
_urVirtualMemUnmap_t = CFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t )

###############################################################################
## @brief Function-pointer for urVirtualMemSetAccess
if __use_win_types:
_urVirtualMemSetAccess_t = WINFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t, ur_virtual_mem_access_flags_t )
else:
_urVirtualMemSetAccess_t = CFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t, ur_virtual_mem_access_flags_t )

###############################################################################
## @brief Function-pointer for urVirtualMemGetInfo
if __use_win_types:
_urVirtualMemGetInfo_t = WINFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t, ur_virtual_mem_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
else:
_urVirtualMemGetInfo_t = CFUNCTYPE( ur_result_t, ur_context_handle_t, c_void_p, c_size_t, ur_virtual_mem_info_t, c_size_t, c_void_p, POINTER(c_size_t) )


###############################################################################
## @brief Table of VirtualMem functions pointers
class ur_virtual_mem_dditable_t(Structure):
_fields_ = [
("pfnGranularityGetInfo", c_void_p), ## _urVirtualMemGranularityGetInfo_t
("pfnReserve", c_void_p), ## _urVirtualMemReserve_t
("pfnFree", c_void_p), ## _urVirtualMemFree_t
("pfnMap", c_void_p), ## _urVirtualMemMap_t
("pfnUnmap", c_void_p), ## _urVirtualMemUnmap_t
("pfnSetAccess", c_void_p), ## _urVirtualMemSetAccess_t
("pfnGetInfo", c_void_p) ## _urVirtualMemGetInfo_t
]

###############################################################################
## @brief Function-pointer for urDeviceGet
if __use_win_types:
Expand Down Expand Up @@ -3292,13 +3455,15 @@ class ur_dditable_t(Structure):
("Kernel", ur_kernel_dditable_t),
("Sampler", ur_sampler_dditable_t),
("Mem", ur_mem_dditable_t),
("PhysicalMem", ur_physical_mem_dditable_t),
("Enqueue", ur_enqueue_dditable_t),
("Queue", ur_queue_dditable_t),
("BindlessImagesExp", ur_bindless_images_exp_dditable_t),
("USM", ur_usm_dditable_t),
("USMExp", ur_usm_exp_dditable_t),
("CommandBufferExp", ur_command_buffer_exp_dditable_t),
("Global", ur_global_dditable_t),
("VirtualMem", ur_virtual_mem_dditable_t),
("Device", ur_device_dditable_t)
]

Expand Down Expand Up @@ -3447,6 +3612,18 @@ def __init__(self, version : ur_api_version_t):
self.urMemGetInfo = _urMemGetInfo_t(self.__dditable.Mem.pfnGetInfo)
self.urMemImageGetInfo = _urMemImageGetInfo_t(self.__dditable.Mem.pfnImageGetInfo)

# call driver to get function pointers
PhysicalMem = ur_physical_mem_dditable_t()
r = ur_result_v(self.__dll.urGetPhysicalMemProcAddrTable(version, byref(PhysicalMem)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.PhysicalMem = PhysicalMem

# attach function interface to function address
self.urPhysicalMemCreate = _urPhysicalMemCreate_t(self.__dditable.PhysicalMem.pfnCreate)
self.urPhysicalMemRetain = _urPhysicalMemRetain_t(self.__dditable.PhysicalMem.pfnRetain)
self.urPhysicalMemRelease = _urPhysicalMemRelease_t(self.__dditable.PhysicalMem.pfnRelease)

# call driver to get function pointers
Enqueue = ur_enqueue_dditable_t()
r = ur_result_v(self.__dll.urGetEnqueueProcAddrTable(version, byref(Enqueue)))
Expand Down Expand Up @@ -3581,6 +3758,22 @@ def __init__(self, version : ur_api_version_t):
self.urInit = _urInit_t(self.__dditable.Global.pfnInit)
self.urTearDown = _urTearDown_t(self.__dditable.Global.pfnTearDown)

# call driver to get function pointers
VirtualMem = ur_virtual_mem_dditable_t()
r = ur_result_v(self.__dll.urGetVirtualMemProcAddrTable(version, byref(VirtualMem)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.VirtualMem = VirtualMem

# attach function interface to function address
self.urVirtualMemGranularityGetInfo = _urVirtualMemGranularityGetInfo_t(self.__dditable.VirtualMem.pfnGranularityGetInfo)
self.urVirtualMemReserve = _urVirtualMemReserve_t(self.__dditable.VirtualMem.pfnReserve)
self.urVirtualMemFree = _urVirtualMemFree_t(self.__dditable.VirtualMem.pfnFree)
self.urVirtualMemMap = _urVirtualMemMap_t(self.__dditable.VirtualMem.pfnMap)
self.urVirtualMemUnmap = _urVirtualMemUnmap_t(self.__dditable.VirtualMem.pfnUnmap)
self.urVirtualMemSetAccess = _urVirtualMemSetAccess_t(self.__dditable.VirtualMem.pfnSetAccess)
self.urVirtualMemGetInfo = _urVirtualMemGetInfo_t(self.__dditable.VirtualMem.pfnGetInfo)

# call driver to get function pointers
Device = ur_device_dditable_t()
r = ur_result_v(self.__dll.urGetDeviceProcAddrTable(version, byref(Device)))
Expand Down
Loading