Skip to content

Commit

Permalink
[UR] Pull in fix that allows handling location properties in USM allo…
Browse files Browse the repository at this point in the history
…cs. (intel#12005)

Also handle translating these properties in pi2ur.

oneapi-src/unified-runtime#1123

---------

Co-authored-by: Kenneth Benzie (Benie) <k.benzie@codeplay.com>
  • Loading branch information
aarongreig and kbenzie authored Dec 11, 2023
1 parent 8074617 commit 7bd51c6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 32 deletions.
12 changes: 6 additions & 6 deletions sycl/plugins/unified_runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ if(SYCL_PI_UR_USE_FETCH_CONTENT)
include(FetchContent)

set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git")
# commit e69ed21468e04ed6e832accf162422ed11736446
# Merge: 20fa0b5f 7fd9dafd
# commit 69a56ea6d1369a6bde5fce97c85fc7dbda49252f
# Merge: b25bb64d b78f541d
# Author: Kenneth Benzie (Benie) <k.benzie@codeplay.com>
# Date: Fri Dec 8 12:18:51 2023 +0000
# Merge pull request #962 from jandres742/fixwaitbarrierwithevent
# [UR][L0] Correctly wait on barrier on urEnqueueEventsWaitWithBarrier
set(UNIFIED_RUNTIME_TAG e69ed21468e04ed6e832accf162422ed11736446)
# Date: Mon Dec 11 12:30:24 2023 +0000
# Merge pull request #1123 from aarongreig/aaron/usmLocationProps
# [OpenCL] Add ur_usm_alloc_location_desc struct and handle it in the CL adapter.
set(UNIFIED_RUNTIME_TAG 69a56ea6d1369a6bde5fce97c85fc7dbda49252f)

if(SYCL_PI_UR_OVERRIDE_FETCH_CONTENT_REPO)
set(UNIFIED_RUNTIME_REPO "${SYCL_PI_UR_OVERRIDE_FETCH_CONTENT_REPO}")
Expand Down
99 changes: 73 additions & 26 deletions sycl/plugins/unified_runtime/pi2ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2697,12 +2697,28 @@ inline pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags,
inline pi_result piextUSMHostAlloc(void **ResultPtr, pi_context Context,
pi_usm_mem_properties *Properties,
size_t Size, pi_uint32 Alignment) {
ur_usm_desc_t USMDesc{};
USMDesc.align = Alignment;

ur_usm_alloc_location_desc_t UsmLocationDesc{};
UsmLocationDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC;

if (Properties) {
uint32_t Next = 0;
while (Properties[Next]) {
if (Properties[Next] == PI_MEM_USM_ALLOC_BUFFER_LOCATION) {
UsmLocationDesc.location = static_cast<uint32_t>(Properties[Next + 1]);
USMDesc.pNext = &UsmLocationDesc;
} else {
return PI_ERROR_INVALID_VALUE;
}
Next += 2;
}
}

std::ignore = Properties;
ur_context_handle_t UrContext =
reinterpret_cast<ur_context_handle_t>(Context);
ur_usm_desc_t USMDesc{};
USMDesc.align = Alignment;

ur_usm_pool_handle_t Pool{};
HANDLE_ERRORS(urUSMHostAlloc(UrContext, &USMDesc, Pool, Size, ResultPtr));
return PI_SUCCESS;
Expand Down Expand Up @@ -3131,14 +3147,29 @@ inline pi_result piextUSMDeviceAlloc(void **ResultPtr, pi_context Context,
pi_device Device,
pi_usm_mem_properties *Properties,
size_t Size, pi_uint32 Alignment) {

std::ignore = Properties;
ur_context_handle_t UrContext =
reinterpret_cast<ur_context_handle_t>(Context);
auto UrDevice = reinterpret_cast<ur_device_handle_t>(Device);

ur_usm_desc_t USMDesc{};
USMDesc.align = Alignment;

ur_usm_alloc_location_desc_t UsmLocDesc{};
UsmLocDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC;

if (Properties) {
uint32_t Next = 0;
while (Properties[Next]) {
if (Properties[Next] == PI_MEM_USM_ALLOC_BUFFER_LOCATION) {
UsmLocDesc.location = static_cast<uint32_t>(Properties[Next + 1]);
USMDesc.pNext = &UsmLocDesc;
} else {
return PI_ERROR_INVALID_VALUE;
}
Next += 2;
}
}

ur_usm_pool_handle_t Pool{};
HANDLE_ERRORS(
urUSMDeviceAlloc(UrContext, UrDevice, &USMDesc, Pool, Size, ResultPtr));
Expand Down Expand Up @@ -3171,42 +3202,58 @@ inline pi_result piextUSMSharedAlloc(void **ResultPtr, pi_context Context,
pi_device Device,
pi_usm_mem_properties *Properties,
size_t Size, pi_uint32 Alignment) {

std::ignore = Properties;
if (Properties && *Properties != 0) {
PI_ASSERT(*(Properties) == PI_MEM_ALLOC_FLAGS && *(Properties + 2) == 0,
PI_ERROR_INVALID_VALUE);
}

ur_context_handle_t UrContext =
reinterpret_cast<ur_context_handle_t>(Context);
auto UrDevice = reinterpret_cast<ur_device_handle_t>(Device);

ur_usm_desc_t USMDesc{};
USMDesc.align = Alignment;
ur_usm_device_desc_t UsmDeviceDesc{};
UsmDeviceDesc.stype = UR_STRUCTURE_TYPE_USM_DEVICE_DESC;
ur_usm_host_desc_t UsmHostDesc{};
UsmHostDesc.stype = UR_STRUCTURE_TYPE_USM_HOST_DESC;
ur_usm_alloc_location_desc_t UsmLocationDesc{};
UsmLocationDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC;

// One properties bitfield can correspond to a host_desc and a device_desc
// struct, since having `0` values in these is harmless we can set up this
// pNext chain in advance.
USMDesc.pNext = &UsmDeviceDesc;
UsmDeviceDesc.pNext = &UsmHostDesc;

if (Properties) {
if (Properties[0] == PI_MEM_ALLOC_FLAGS) {
if (Properties[1] == PI_MEM_ALLOC_WRTITE_COMBINED) {
UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED;
}
if (Properties[1] == PI_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE) {
UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT;
uint32_t Next = 0;
while (Properties[Next]) {
switch (Properties[Next]) {
case PI_MEM_ALLOC_FLAGS: {
if (Properties[Next + 1] & PI_MEM_ALLOC_WRTITE_COMBINED) {
UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED;
}
if (Properties[Next + 1] & PI_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE) {
UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT;
}
if (Properties[Next + 1] & PI_MEM_ALLOC_INITIAL_PLACEMENT_HOST) {
UsmHostDesc.flags |= UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT;
}
if (Properties[Next + 1] & PI_MEM_ALLOC_DEVICE_READ_ONLY) {
UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_DEVICE_READ_ONLY;
}
break;
}
if (Properties[1] == PI_MEM_ALLOC_INITIAL_PLACEMENT_HOST) {
UsmHostDesc.flags |= UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT;
case PI_MEM_USM_ALLOC_BUFFER_LOCATION: {
UsmLocationDesc.location = static_cast<uint32_t>(Properties[Next + 1]);
// We wait until we've seen a BUFFER_LOCATION property to tack this
// onto the end of the chain, a `0` here might be valid as far as we
// know so we must exclude it unless we've been given a value.
UsmHostDesc.pNext = &UsmLocationDesc;
break;
}
if (Properties[1] == PI_MEM_ALLOC_DEVICE_READ_ONLY) {
UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_DEVICE_READ_ONLY;
default:
return PI_ERROR_INVALID_VALUE;
}
Next += 2;
}
}
UsmDeviceDesc.pNext = &UsmHostDesc;
USMDesc.pNext = &UsmDeviceDesc;

USMDesc.align = Alignment;

ur_usm_pool_handle_t Pool{};
HANDLE_ERRORS(
Expand Down

0 comments on commit 7bd51c6

Please sign in to comment.