Skip to content

Commit

Permalink
[UR][L0][Bindless] Create sampled image from image and sampler descri…
Browse files Browse the repository at this point in the history
…ptors
  • Loading branch information
wenju-he committed Mar 14, 2024
1 parent 1cf9a08 commit 2865cd8
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 107 deletions.
195 changes: 88 additions & 107 deletions source/adapters/level_zero/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,6 @@ zeMemGetPitchFor2dImage_pfn zeMemGetPitchFor2dImageFunctionPtr = nullptr;

zeImageGetDeviceOffsetExp_pfn zeImageGetDeviceOffsetExpFunctionPtr = nullptr;

/// Return true if the two image_desc are the same.
bool isSameImageDesc(const ze_image_desc_t *Desc1,
const ze_image_desc_t *Desc2) {
auto IsSameImageFormat = [](const ze_image_format_t &Format1,
const ze_image_format_t &Format2) {
return Format1.layout == Format2.layout && Format1.type == Format2.type &&
Format1.x == Format2.x && Format1.y == Format2.y &&
Format1.z == Format2.z && Format1.w == Format2.w;
};
return Desc1->stype == Desc2->stype && Desc1->flags == Desc2->flags &&
Desc1->type == Desc2->type &&
IsSameImageFormat(Desc1->format, Desc2->format) &&
Desc1->width == Desc2->width && Desc1->height == Desc2->height &&
Desc1->depth == Desc2->depth &&
Desc1->arraylevels == Desc2->arraylevels &&
Desc1->miplevels == Desc2->miplevels;
}

/// Construct UR image format from ZE image desc.
ur_result_t ze2urImageFormat(const ze_image_desc_t *ZeImageDesc,
ur_image_format_t *UrImageFormat) {
Expand Down Expand Up @@ -425,6 +407,90 @@ uint32_t getPixelSizeBytes(const ur_image_format_t *Format) {
return NumChannels * ChannelTypeSizeInBytes;
}

ur_result_t bindlessImagesCreateImpl(
ur_context_handle_t hContext, ur_device_handle_t hDevice,
ur_exp_image_mem_handle_t hImageMem, const ur_image_format_t *pImageFormat,
const ur_image_desc_t *pImageDesc, ur_sampler_handle_t hSampler, ur_mem_handle_t *phMem,
ur_exp_image_handle_t *phImage) {
std::shared_lock<ur_shared_mutex> Lock(hContext->Mutex);

UR_ASSERT(hContext && hDevice && hImageMem,
UR_RESULT_ERROR_INVALID_NULL_HANDLE);
UR_ASSERT(pImageFormat && pImageDesc && phMem && phImage,
UR_RESULT_ERROR_INVALID_NULL_POINTER);

ZeStruct<ze_image_desc_t> ZeImageDesc;
UR_CALL(ur2zeImageDesc(pImageFormat, pImageDesc, ZeImageDesc));

ze_image_bindless_exp_desc_t BindlessDesc;
ZeImageDesc.pNext = &BindlessDesc;

ZeStruct<ze_sampler_desc_t> ZeSamplerDesc;
if (hSampler) {
ZeSamplerDesc = hSampler->ZeSamplerDesc;
}

ze_image_handle_t ZeImage;

ze_memory_allocation_properties_t MemAllocProperties{
ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES};
ZE2UR_CALL(zeMemGetAllocProperties,
(hContext->ZeContext, hImageMem, &MemAllocProperties, nullptr));
if (MemAllocProperties.type == ZE_MEMORY_TYPE_UNKNOWN) {
_ur_image *UrImage = reinterpret_cast<_ur_image *>(hImageMem);

BindlessDesc.stype = ZE_STRUCTURE_TYPE_BINDLESS_IMAGE_EXP_DESC;
BindlessDesc.pNext = hSampler ? &ZeSamplerDesc : nullptr;
BindlessDesc.flags = ZE_IMAGE_BINDLESS_EXP_FLAG_BINDLESS;

ZE2UR_CALL(zeImageViewCreateExt,
(hContext->ZeContext, hDevice->ZeDevice, &ZeImageDesc,
UrImage->ZeImage, &ZeImage));
ZE2UR_CALL(zeContextMakeImageResident,
(hContext->ZeContext, hDevice->ZeDevice, ZeImage));
UR_CALL(createUrMemFromZeImage(hContext, ZeImage, /*OwnZeMemHandle*/ true,
ZeImageDesc, phMem));
} else if (MemAllocProperties.type == ZE_MEMORY_TYPE_DEVICE) {
ze_image_pitched_exp_desc_t PitchedDesc;
PitchedDesc.stype = ZE_STRUCTURE_TYPE_PITCHED_IMAGE_EXP_DESC;
PitchedDesc.pNext = hSampler ? &ZeSamplerDesc : nullptr;
PitchedDesc.ptr = hImageMem;

BindlessDesc.stype = ZE_STRUCTURE_TYPE_BINDLESS_IMAGE_EXP_DESC;
BindlessDesc.pNext = &PitchedDesc;
BindlessDesc.flags = ZE_IMAGE_BINDLESS_EXP_FLAG_BINDLESS;

ZE2UR_CALL(zeImageCreate, (hContext->ZeContext, hDevice->ZeDevice,
&ZeImageDesc, &ZeImage));
ZE2UR_CALL(zeContextMakeImageResident,
(hContext->ZeContext, hDevice->ZeDevice, ZeImage));
UR_CALL(createUrMemFromZeImage(hContext, ZeImage, /*OwnZeMemHandle*/ true,
ZeImageDesc, phMem));
} else {
return UR_RESULT_ERROR_INVALID_VALUE;
}

static std::once_flag InitFlag;
std::call_once(InitFlag, [&]() {
ze_driver_handle_t DriverHandle = hContext->getPlatform()->ZeDriver;
auto Result = zeDriverGetExtensionFunctionAddress(
DriverHandle, "zeImageGetDeviceOffsetExp",
(void **)&zeImageGetDeviceOffsetExpFunctionPtr);
if (Result != ZE_RESULT_SUCCESS)
urPrint("zeDriverGetExtensionFunctionAddress zeImageGetDeviceOffsetExp "
"failed, err = %d\n",
Result);
});
if (!zeImageGetDeviceOffsetExpFunctionPtr)
return UR_RESULT_ERROR_INVALID_OPERATION;

uint64_t DeviceOffset{};
ZE2UR_CALL(zeImageGetDeviceOffsetExpFunctionPtr, (ZeImage, &DeviceOffset));
*phImage = reinterpret_cast<ur_exp_image_handle_t>(DeviceOffset);

return UR_RESULT_SUCCESS;
}

} // namespace

ur_result_t getImageRegionHelper(ze_image_desc_t ZeImageDesc,
Expand Down Expand Up @@ -649,82 +715,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp(
ur_exp_image_mem_handle_t hImageMem, const ur_image_format_t *pImageFormat,
const ur_image_desc_t *pImageDesc, ur_mem_handle_t *phMem,
ur_exp_image_handle_t *phImage) {
std::shared_lock<ur_shared_mutex> Lock(hContext->Mutex);

UR_ASSERT(hContext && hDevice && hImageMem,
UR_RESULT_ERROR_INVALID_NULL_HANDLE);
UR_ASSERT(pImageFormat && pImageDesc && phMem && phImage,
UR_RESULT_ERROR_INVALID_NULL_POINTER);

ZeStruct<ze_image_desc_t> ZeImageDesc;
UR_CALL(ur2zeImageDesc(pImageFormat, pImageDesc, ZeImageDesc));

ze_image_handle_t ZeImage;

ze_memory_allocation_properties_t MemAllocProperties{
ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES};
ZE2UR_CALL(zeMemGetAllocProperties,
(hContext->ZeContext, hImageMem, &MemAllocProperties, nullptr));
if (MemAllocProperties.type == ZE_MEMORY_TYPE_UNKNOWN) {
_ur_image *UrImage = reinterpret_cast<_ur_image *>(hImageMem);
if (!isSameImageDesc(&UrImage->ZeImageDesc, &ZeImageDesc)) {
ze_image_bindless_exp_desc_t ZeImageBindlessDesc;
ZeImageBindlessDesc.stype = ZE_STRUCTURE_TYPE_BINDLESS_IMAGE_EXP_DESC;
ZeImageBindlessDesc.pNext = nullptr;
ZeImageBindlessDesc.flags = ZE_IMAGE_BINDLESS_EXP_FLAG_BINDLESS;
ZeImageDesc.pNext = &ZeImageBindlessDesc;
ZE2UR_CALL(zeImageViewCreateExt,
(hContext->ZeContext, hDevice->ZeDevice, &ZeImageDesc,
UrImage->ZeImage, &ZeImage));
ZE2UR_CALL(zeContextMakeImageResident,
(hContext->ZeContext, hDevice->ZeDevice, ZeImage));
UR_CALL(createUrMemFromZeImage(hContext, ZeImage, /*OwnZeMemHandle*/ true,
ZeImageDesc, phMem));
} else {
ZeImage = UrImage->ZeImage;
*phMem = nullptr;
}
} else if (MemAllocProperties.type == ZE_MEMORY_TYPE_DEVICE) {
ze_image_pitched_exp_desc_t PitchedDesc;
PitchedDesc.stype = ZE_STRUCTURE_TYPE_PITCHED_IMAGE_EXP_DESC;
PitchedDesc.pNext = nullptr;
PitchedDesc.ptr = hImageMem;

ze_image_bindless_exp_desc_t BindlessDesc;
BindlessDesc.stype = ZE_STRUCTURE_TYPE_BINDLESS_IMAGE_EXP_DESC;
BindlessDesc.pNext = &PitchedDesc;
BindlessDesc.flags = ZE_IMAGE_BINDLESS_EXP_FLAG_BINDLESS;

ZeImageDesc.pNext = &BindlessDesc;

ZE2UR_CALL(zeImageCreate, (hContext->ZeContext, hDevice->ZeDevice,
&ZeImageDesc, &ZeImage));
ZE2UR_CALL(zeContextMakeImageResident,
(hContext->ZeContext, hDevice->ZeDevice, ZeImage));
UR_CALL(createUrMemFromZeImage(hContext, ZeImage, /*OwnZeMemHandle*/ true,
ZeImageDesc, phMem));
} else {
return UR_RESULT_ERROR_INVALID_VALUE;
}

static std::once_flag InitFlag;
std::call_once(InitFlag, [&]() {
ze_driver_handle_t DriverHandle = hContext->getPlatform()->ZeDriver;
auto Result = zeDriverGetExtensionFunctionAddress(
DriverHandle, "zeImageGetDeviceOffsetExp",
(void **)&zeImageGetDeviceOffsetExpFunctionPtr);
if (Result != ZE_RESULT_SUCCESS)
urPrint("zeDriverGetExtensionFunctionAddress zeImageGetDeviceOffsetExp "
"failed, err = %d\n",
Result);
});
if (!zeImageGetDeviceOffsetExpFunctionPtr)
return UR_RESULT_ERROR_INVALID_OPERATION;

uint64_t DeviceOffset{};
ZE2UR_CALL(zeImageGetDeviceOffsetExpFunctionPtr, (ZeImage, &DeviceOffset));
*phImage = reinterpret_cast<ur_exp_image_handle_t>(DeviceOffset);

UR_CALL(bindlessImagesCreateImpl(
hContext, hDevice, hImageMem, pImageFormat, pImageDesc, nullptr, phMem, phImage));
return UR_RESULT_SUCCESS;
}

Expand All @@ -733,19 +725,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
ur_exp_image_mem_handle_t hImageMem, const ur_image_format_t *pImageFormat,
const ur_image_desc_t *pImageDesc, ur_sampler_handle_t hSampler,
ur_mem_handle_t *phMem, ur_exp_image_handle_t *phImage) {

UR_CALL(urBindlessImagesUnsampledImageCreateExp(
hContext, hDevice, hImageMem, pImageFormat, pImageDesc, phMem, phImage));

struct combined_sampled_image_handle {
uint64_t RawImageHandle;
uint64_t RawSamplerHandle;
};
auto *SampledImageHandle =
reinterpret_cast<combined_sampled_image_handle *>(phImage);
SampledImageHandle->RawSamplerHandle =
reinterpret_cast<uint64_t>(hSampler->ZeSampler);

UR_CALL(bindlessImagesCreateImpl(
hContext, hDevice, hImageMem, pImageFormat, pImageDesc, hSampler, phMem, phImage));
return UR_RESULT_SUCCESS;
}

Expand Down
1 change: 1 addition & 0 deletions source/adapters/level_zero/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urSamplerCreate(

try {
ur_sampler_handle_t_ *UrSampler = new ur_sampler_handle_t_(ZeSampler);
UrSampler->ZeSamplerDesc = ZeSamplerDesc;
*Sampler = reinterpret_cast<ur_sampler_handle_t>(UrSampler);
} catch (const std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/level_zero/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ struct ur_sampler_handle_t_ : _ur_object {

// Level Zero sampler handle.
ze_sampler_handle_t ZeSampler;

ZeStruct<ze_sampler_desc_t> ZeSamplerDesc;
};

0 comments on commit 2865cd8

Please sign in to comment.