diff --git a/source/adapters/level_zero/CMakeLists.txt b/source/adapters/level_zero/CMakeLists.txt index adde1ec7d6..62ba260dea 100644 --- a/source/adapters/level_zero/CMakeLists.txt +++ b/source/adapters/level_zero/CMakeLists.txt @@ -34,7 +34,7 @@ if (NOT DEFINED LEVEL_ZERO_LIBRARY OR NOT DEFINED LEVEL_ZERO_INCLUDE_DIR) endif() set(LEVEL_ZERO_LOADER_REPO "https://github.com/oneapi-src/level-zero.git") - set(LEVEL_ZERO_LOADER_TAG v1.15.1) + set(LEVEL_ZERO_LOADER_TAG v1.16.1) # Disable due to a bug https://github.com/oneapi-src/level-zero/issues/104 set(CMAKE_INCLUDE_CURRENT_DIR OFF) diff --git a/source/adapters/level_zero/device.cpp b/source/adapters/level_zero/device.cpp index 918b04400a..2f6b3a91ff 100644 --- a/source/adapters/level_zero/device.cpp +++ b/source/adapters/level_zero/device.cpp @@ -919,6 +919,31 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo( return ReturnValue(true); case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP: return ReturnValue(false); + case UR_DEVICE_INFO_BINDLESS_IMAGES_SUPPORT_EXP: + return ReturnValue(true); + case UR_DEVICE_INFO_BINDLESS_IMAGES_SHARED_USM_SUPPORT_EXP: + return ReturnValue(true); + case UR_DEVICE_INFO_BINDLESS_IMAGES_1D_USM_SUPPORT_EXP: + return ReturnValue(false); + case UR_DEVICE_INFO_BINDLESS_IMAGES_2D_USM_SUPPORT_EXP: + return ReturnValue(true); + case UR_DEVICE_INFO_IMAGE_PITCH_ALIGN_EXP: + case UR_DEVICE_INFO_MAX_IMAGE_LINEAR_WIDTH_EXP: + case UR_DEVICE_INFO_MAX_IMAGE_LINEAR_HEIGHT_EXP: + case UR_DEVICE_INFO_MAX_IMAGE_LINEAR_PITCH_EXP: + urPrint("Unsupported ParamName in urGetDeviceInfo\n"); + urPrint("ParamName=%d(0x%x)\n", ParamName, ParamName); + return UR_RESULT_ERROR_INVALID_VALUE; + case UR_DEVICE_INFO_MIPMAP_SUPPORT_EXP: + return ReturnValue(true); + case UR_DEVICE_INFO_MIPMAP_ANISOTROPY_SUPPORT_EXP: + return ReturnValue(true); + case UR_DEVICE_INFO_MIPMAP_MAX_ANISOTROPY_EXP: + case UR_DEVICE_INFO_MIPMAP_LEVEL_REFERENCE_SUPPORT_EXP: + case UR_DEVICE_INFO_INTEROP_MEMORY_IMPORT_SUPPORT_EXP: + 
case UR_DEVICE_INFO_INTEROP_MEMORY_EXPORT_SUPPORT_EXP: + case UR_DEVICE_INFO_INTEROP_SEMAPHORE_IMPORT_SUPPORT_EXP: + case UR_DEVICE_INFO_INTEROP_SEMAPHORE_EXPORT_SUPPORT_EXP: default: urPrint("Unsupported ParamName in urGetDeviceInfo\n"); urPrint("ParamName=%d(0x%x)\n", ParamName, ParamName); diff --git a/source/adapters/level_zero/image.cpp b/source/adapters/level_zero/image.cpp index 3e2b78488f..0d986de0ad 100644 --- a/source/adapters/level_zero/image.cpp +++ b/source/adapters/level_zero/image.cpp @@ -10,23 +10,579 @@ #include "image.hpp" #include "common.hpp" +#include "context.hpp" +#include "event.hpp" +#include "sampler.hpp" +#include "ur_level_zero.hpp" + +typedef ze_result_t(ZE_APICALL *zeImageGetDeviceOffsetExp_pfn)( + ze_image_handle_t hImage, uint64_t *pDeviceOffset); + +typedef ze_result_t(ZE_APICALL *zeMemGetPitchFor2dImage_pfn)( + ze_context_handle_t hContext, ze_device_handle_t hDevice, size_t imageWidth, + size_t imageHeight, unsigned int elementSizeInBytes, size_t *rowPitch); + +namespace { + +zeMemGetPitchFor2dImage_pfn zeMemGetPitchFor2dImageFunctionPtr = nullptr; + +zeImageGetDeviceOffsetExp_pfn zeImageGetDeviceOffsetExpFunctionPtr = nullptr; + +/// Return true if the two image_desc are the same. 
+bool isSameImageDesc(const ze_image_desc_t *Desc1,
+                     const ze_image_desc_t *Desc2) {
+  auto IsSameImageFormat = [](const ze_image_format_t &Format1,
+                              const ze_image_format_t &Format2) {
+    return Format1.layout == Format2.layout && Format1.type == Format2.type &&
+           Format1.x == Format2.x && Format1.y == Format2.y &&
+           Format1.z == Format2.z && Format1.w == Format2.w;
+  };
+  return Desc1->stype == Desc2->stype && Desc1->flags == Desc2->flags &&
+         Desc1->type == Desc2->type &&
+         IsSameImageFormat(Desc1->format, Desc2->format) &&
+         Desc1->width == Desc2->width && Desc1->height == Desc2->height &&
+         Desc1->depth == Desc2->depth &&
+         Desc1->arraylevels == Desc2->arraylevels &&
+         Desc1->miplevels == Desc2->miplevels;
+}
+
+/// Construct UR image format (channel order + type) from ZE image desc.
+ur_result_t ze2urImageFormat(const ze_image_desc_t *ZeImageDesc,
+                             ur_image_format_t *UrImageFormat) {
+  const ze_image_format_t &ZeImageFormat = ZeImageDesc->format;
+  size_t ZeImageFormatTypeSize; // per-channel size in bits
+  switch (ZeImageFormat.layout) {
+  case ZE_IMAGE_FORMAT_LAYOUT_8:
+  case ZE_IMAGE_FORMAT_LAYOUT_8_8:
+  case ZE_IMAGE_FORMAT_LAYOUT_8_8_8_8:
+    ZeImageFormatTypeSize = 8;
+    break;
+  case ZE_IMAGE_FORMAT_LAYOUT_16:
+  case ZE_IMAGE_FORMAT_LAYOUT_16_16:
+  case ZE_IMAGE_FORMAT_LAYOUT_16_16_16_16:
+    ZeImageFormatTypeSize = 16;
+    break;
+  case ZE_IMAGE_FORMAT_LAYOUT_32:
+  case ZE_IMAGE_FORMAT_LAYOUT_32_32:
+  case ZE_IMAGE_FORMAT_LAYOUT_32_32_32_32:
+    ZeImageFormatTypeSize = 32;
+    break;
+  default:
+    urPrint("ze2urImageFormat: unsupported image format layout: layout = %d\n",
+            ZeImageFormat.layout);
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
+
+  ur_image_channel_order_t ChannelOrder;
+  switch (ZeImageFormat.layout) {
+  case ZE_IMAGE_FORMAT_LAYOUT_8:
+  case ZE_IMAGE_FORMAT_LAYOUT_16:
+  case ZE_IMAGE_FORMAT_LAYOUT_32:
+    switch (ZeImageFormat.x) {
+    case ZE_IMAGE_FORMAT_SWIZZLE_R:
+      ChannelOrder = UR_IMAGE_CHANNEL_ORDER_R;
+      break;
+    case ZE_IMAGE_FORMAT_SWIZZLE_A:
+      ChannelOrder = UR_IMAGE_CHANNEL_ORDER_A;
+      break;
+    default:
+      urPrint("ze2urImageFormat: unexpected image format channel x: x = %d\n",
+              ZeImageFormat.x);
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  case ZE_IMAGE_FORMAT_LAYOUT_8_8:
+  case ZE_IMAGE_FORMAT_LAYOUT_16_16:
+  case ZE_IMAGE_FORMAT_LAYOUT_32_32:
+    if (ZeImageFormat.x != ZE_IMAGE_FORMAT_SWIZZLE_R) {
+      urPrint("ze2urImageFormat: unexpected image format channel x: x = %d\n",
+              ZeImageFormat.x);
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    switch (ZeImageFormat.y) {
+    case ZE_IMAGE_FORMAT_SWIZZLE_G:
+      ChannelOrder = UR_IMAGE_CHANNEL_ORDER_RG;
+      break;
+    case ZE_IMAGE_FORMAT_SWIZZLE_A:
+      ChannelOrder = UR_IMAGE_CHANNEL_ORDER_RA;
+      break;
+    case ZE_IMAGE_FORMAT_SWIZZLE_X:
+      ChannelOrder = UR_IMAGE_CHANNEL_ORDER_RX;
+      break;
+    default:
+      urPrint("ze2urImageFormat: unexpected image format channel y: y = %d\n",
+              ZeImageFormat.y);
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  case ZE_IMAGE_FORMAT_LAYOUT_8_8_8_8:
+  case ZE_IMAGE_FORMAT_LAYOUT_16_16_16_16:
+  case ZE_IMAGE_FORMAT_LAYOUT_32_32_32_32:
+    if (ZeImageFormat.x == ZE_IMAGE_FORMAT_SWIZZLE_R &&
+        ZeImageFormat.y == ZE_IMAGE_FORMAT_SWIZZLE_G &&
+        ZeImageFormat.z == ZE_IMAGE_FORMAT_SWIZZLE_B) {
+      switch (ZeImageFormat.w) {
+      case ZE_IMAGE_FORMAT_SWIZZLE_X:
+        ChannelOrder = UR_IMAGE_CHANNEL_ORDER_RGBX;
+        break;
+      case ZE_IMAGE_FORMAT_SWIZZLE_A:
+        ChannelOrder = UR_IMAGE_CHANNEL_ORDER_RGBA;
+        break;
+      default:
+        urPrint("ze2urImageFormat: unexpected image format channel w: w = %d\n",
+                ZeImageFormat.w);
+        return UR_RESULT_ERROR_INVALID_VALUE;
+      }
+    } else if (ZeImageFormat.x == ZE_IMAGE_FORMAT_SWIZZLE_A &&
+               ZeImageFormat.y == ZE_IMAGE_FORMAT_SWIZZLE_R &&
+               ZeImageFormat.z == ZE_IMAGE_FORMAT_SWIZZLE_G &&
+               ZeImageFormat.w == ZE_IMAGE_FORMAT_SWIZZLE_B) {
+      ChannelOrder = UR_IMAGE_CHANNEL_ORDER_ARGB;
+    } else if (ZeImageFormat.x == ZE_IMAGE_FORMAT_SWIZZLE_B &&
+               ZeImageFormat.y == ZE_IMAGE_FORMAT_SWIZZLE_G &&
+               ZeImageFormat.z == ZE_IMAGE_FORMAT_SWIZZLE_R &&
+               ZeImageFormat.w == ZE_IMAGE_FORMAT_SWIZZLE_A) {
+      ChannelOrder = UR_IMAGE_CHANNEL_ORDER_BGRA;
+    } else {
+      urPrint("ze2urImageFormat: unexpected image format channel\n");
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  default:
+    urPrint("ze2urImageFormat: unsupported image format layout: layout = %d\n",
+            ZeImageFormat.layout);
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
+
+  ur_image_channel_type_t ChannelType;
+  switch (ZeImageFormat.type) {
+  case ZE_IMAGE_FORMAT_TYPE_UINT:
+    switch (ZeImageFormatTypeSize) {
+    case 8:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8;
+      break;
+    case 16:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16;
+      break;
+    case 32:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32;
+      break;
+    default:
+      urPrint(
+          "ze2urImageFormat: unexpected image format type size: size = %zu\n",
+          ZeImageFormatTypeSize);
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  case ZE_IMAGE_FORMAT_TYPE_SINT:
+    switch (ZeImageFormatTypeSize) {
+    case 8:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8;
+      break;
+    case 16:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16;
+      break;
+    case 32:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32;
+      break;
+    default:
+      urPrint(
+          "ze2urImageFormat: unexpected image format type size: size = %zu\n",
+          ZeImageFormatTypeSize);
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  case ZE_IMAGE_FORMAT_TYPE_UNORM:
+    switch (ZeImageFormatTypeSize) {
+    case 8:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_UNORM_INT8;
+      break;
+    case 16:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_UNORM_INT16;
+      break;
+    default:
+      urPrint(
+          "ze2urImageFormat: unexpected image format type size: size = %zu\n",
+          ZeImageFormatTypeSize);
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  case ZE_IMAGE_FORMAT_TYPE_SNORM:
+    switch (ZeImageFormatTypeSize) {
+    case 8:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_SNORM_INT8;
+      break;
+    case 16:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_SNORM_INT16;
+      break;
+    default:
+      urPrint(
+          "ze2urImageFormat: unexpected image format type size: size = %zu\n",
+          ZeImageFormatTypeSize);
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  case ZE_IMAGE_FORMAT_TYPE_FLOAT:
+    switch (ZeImageFormatTypeSize) {
+    case 16:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT;
+      break;
+    case 32:
+      ChannelType = UR_IMAGE_CHANNEL_TYPE_FLOAT;
+      break;
+    default:
+      urPrint(
+          "ze2urImageFormat: unexpected image format type size: size = %zu\n",
+          ZeImageFormatTypeSize);
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  default:
+    urPrint("ze2urImageFormat: unsupported image format type: type = %d\n",
+            ZeImageFormat.type);
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
+
+  UrImageFormat->channelOrder = ChannelOrder;
+  UrImageFormat->channelType = ChannelType;
+  return UR_RESULT_SUCCESS;
+}
+
+/// Construct ZE image desc from UR image format and desc.
+ur_result_t ur2zeImageDesc(const ur_image_format_t *ImageFormat,
+                           const ur_image_desc_t *ImageDesc,
+                           ZeStruct<ze_image_desc_t> &ZeImageDesc) {
+  auto [ZeImageFormatType, ZeImageFormatTypeSize] =
+      getImageFormatTypeAndSize(ImageFormat);
+  // TODO: populate the layout mapping
+  ze_image_format_layout_t ZeImageFormatLayout;
+  switch (ImageFormat->channelOrder) {
+  case UR_IMAGE_CHANNEL_ORDER_A:
+  case UR_IMAGE_CHANNEL_ORDER_R: {
+    switch (ZeImageFormatTypeSize) {
+    case 8:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_8;
+      break;
+    case 16:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_16;
+      break;
+    case 32:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_32;
+      break;
+    default:
+      urPrint("ur2zeImageDesc: unexpected data type size\n");
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  }
+  case UR_IMAGE_CHANNEL_ORDER_RG:
+  case UR_IMAGE_CHANNEL_ORDER_RA:
+  case UR_IMAGE_CHANNEL_ORDER_RX: {
+    switch (ZeImageFormatTypeSize) {
+    case 8:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_8_8;
+      break;
+    case 16:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_16_16;
+      break;
+    case 32:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_32_32;
+      break;
+    default:
+      urPrint("ur2zeImageDesc: unexpected data type size\n");
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  }
+  case UR_IMAGE_CHANNEL_ORDER_RGBX:
+  case UR_IMAGE_CHANNEL_ORDER_RGBA:
+  case UR_IMAGE_CHANNEL_ORDER_ARGB:
+  case UR_IMAGE_CHANNEL_ORDER_BGRA: {
+    switch (ZeImageFormatTypeSize) {
+    case 8:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_8_8_8_8;
+      break;
+    case 16:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_16_16_16_16;
+      break;
+    case 32:
+      ZeImageFormatLayout = ZE_IMAGE_FORMAT_LAYOUT_32_32_32_32;
+      break;
+    default:
+      urPrint("ur2zeImageDesc: unexpected data type size\n");
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+    break;
+  }
+  default:
+    urPrint("format channel order = %d\n", ImageFormat->channelOrder);
+    die("ur2zeImageDesc: unsupported image channel order\n");
+    break;
+  }
+
+  ze_image_format_t ZeFormatDesc = {
+      ZeImageFormatLayout,       ZeImageFormatType,
+      // TODO: are swizzles deducted from image_format->image_channel_order?
+      ZE_IMAGE_FORMAT_SWIZZLE_R, ZE_IMAGE_FORMAT_SWIZZLE_G,
+      ZE_IMAGE_FORMAT_SWIZZLE_B, ZE_IMAGE_FORMAT_SWIZZLE_A};
+
+  ze_image_type_t ZeImageType;
+  switch (ImageDesc->type) {
+  case UR_MEM_TYPE_IMAGE1D:
+    ZeImageType = ZE_IMAGE_TYPE_1D;
+    break;
+  case UR_MEM_TYPE_IMAGE2D:
+    ZeImageType = ZE_IMAGE_TYPE_2D;
+    break;
+  case UR_MEM_TYPE_IMAGE3D:
+    ZeImageType = ZE_IMAGE_TYPE_3D;
+    break;
+  case UR_MEM_TYPE_IMAGE1D_ARRAY:
+    ZeImageType = ZE_IMAGE_TYPE_1DARRAY;
+    break;
+  case UR_MEM_TYPE_IMAGE2D_ARRAY:
+    ZeImageType = ZE_IMAGE_TYPE_2DARRAY;
+    break;
+  default:
+    urPrint("ur2zeImageDesc: unsupported image type\n");
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
+
+  ZeImageDesc.stype = ZE_STRUCTURE_TYPE_IMAGE_DESC;
+  ZeImageDesc.pNext = ImageDesc->pNext;
+  ZeImageDesc.arraylevels = ZeImageDesc.flags = 0;
+  ZeImageDesc.type = ZeImageType;
+  ZeImageDesc.format = ZeFormatDesc;
+  ZeImageDesc.width = ur_cast<uint64_t>(ImageDesc->width);
+  ZeImageDesc.height =
+      std::max(ur_cast<uint64_t>(ImageDesc->height), (uint64_t)1);
+  ZeImageDesc.depth =
+      std::max(ur_cast<uint64_t>(ImageDesc->depth), (uint64_t)1);
+  ZeImageDesc.arraylevels = ur_cast<uint32_t>(ImageDesc->arraySize);
+  ZeImageDesc.miplevels = ImageDesc->numMipLevel;
+
+  return UR_RESULT_SUCCESS;
+}
+
+/// Return element size in bytes of a pixel (channel count * channel size).
+uint32_t getPixelSizeBytes(const ur_image_format_t *Format) {
+  uint32_t NumChannels = 0;
+  switch (Format->channelOrder) {
+  case UR_IMAGE_CHANNEL_ORDER_A:
+  case UR_IMAGE_CHANNEL_ORDER_R:
+  case UR_IMAGE_CHANNEL_ORDER_INTENSITY:
+  case UR_IMAGE_CHANNEL_ORDER_LUMINANCE:
+  case UR_IMAGE_CHANNEL_ORDER_FORCE_UINT32:
+    NumChannels = 1;
+    break;
+  case UR_IMAGE_CHANNEL_ORDER_RG:
+  case UR_IMAGE_CHANNEL_ORDER_RA:
+  case UR_IMAGE_CHANNEL_ORDER_RX:
+    NumChannels = 2;
+    break;
+  case UR_IMAGE_CHANNEL_ORDER_RGB:
+  case UR_IMAGE_CHANNEL_ORDER_RGX:
+    NumChannels = 3;
+    break;
+  case UR_IMAGE_CHANNEL_ORDER_RGBA:
+  case UR_IMAGE_CHANNEL_ORDER_BGRA:
+  case UR_IMAGE_CHANNEL_ORDER_ARGB:
+  case UR_IMAGE_CHANNEL_ORDER_ABGR:
+  case UR_IMAGE_CHANNEL_ORDER_RGBX:
+  case UR_IMAGE_CHANNEL_ORDER_SRGBA:
+    NumChannels = 4;
+    break;
+  default:
+    ur::unreachable();
+  }
+  uint32_t ChannelTypeSizeInBytes = 0;
+  switch (Format->channelType) {
+  case UR_IMAGE_CHANNEL_TYPE_SNORM_INT8:
+  case UR_IMAGE_CHANNEL_TYPE_UNORM_INT8:
+  case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8:
+  case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8:
+    ChannelTypeSizeInBytes = 1;
+    break;
+  case UR_IMAGE_CHANNEL_TYPE_SNORM_INT16:
+  case UR_IMAGE_CHANNEL_TYPE_UNORM_INT16:
+  case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16:
+  case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16:
+  case UR_IMAGE_CHANNEL_TYPE_UNORM_SHORT_565:
+  case UR_IMAGE_CHANNEL_TYPE_UNORM_SHORT_555:
+  case UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT: // 16-bit float channel: 2 bytes
+    ChannelTypeSizeInBytes = 2;
+    break;
+  case UR_IMAGE_CHANNEL_TYPE_INT_101010:
+  case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32:
+  case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32:
+  case UR_IMAGE_CHANNEL_TYPE_FLOAT:
+  case UR_IMAGE_CHANNEL_TYPE_FORCE_UINT32:
+    ChannelTypeSizeInBytes = 4;
+    break;
+  default:
+    ur::unreachable();
+  }
+  return NumChannels * ChannelTypeSizeInBytes;
+}
+
+} // namespace
+
+ur_result_t getImageRegionHelper(ze_image_desc_t ZeImageDesc,
+                                 ur_rect_offset_t *Origin,
+                                 ur_rect_region_t *Region,
+                                 ze_image_region_t &ZeRegion) {
+  UR_ASSERT(Origin, UR_RESULT_ERROR_INVALID_VALUE);
+  UR_ASSERT(Region, UR_RESULT_ERROR_INVALID_VALUE);
+
+  if (ZeImageDesc.type == ZE_IMAGE_TYPE_1D) {
+    Region->height = 1;
+    Region->depth = 1;
+  } else if (ZeImageDesc.type == ZE_IMAGE_TYPE_2D ||
+             ZeImageDesc.type == ZE_IMAGE_TYPE_1DARRAY) {
+    Region->depth = 1;
+  }
+
+#ifndef NDEBUG
+  UR_ASSERT((ZeImageDesc.type == ZE_IMAGE_TYPE_1D && Origin->y == 0 &&
+             Origin->z == 0) ||
+                (ZeImageDesc.type == ZE_IMAGE_TYPE_1DARRAY && Origin->z == 0) ||
+                (ZeImageDesc.type == ZE_IMAGE_TYPE_2D && Origin->z == 0) ||
+                (ZeImageDesc.type == ZE_IMAGE_TYPE_3D),
+            UR_RESULT_ERROR_INVALID_VALUE);
+
+  UR_ASSERT(Region->width && Region->height && Region->depth,
+            UR_RESULT_ERROR_INVALID_VALUE);
+  UR_ASSERT(
+      (ZeImageDesc.type == ZE_IMAGE_TYPE_1D && Region->height == 1 &&
+       Region->depth == 1) ||
+          (ZeImageDesc.type == ZE_IMAGE_TYPE_1DARRAY && Region->depth == 1) ||
+          (ZeImageDesc.type == ZE_IMAGE_TYPE_2D && Region->depth == 1) ||
+          (ZeImageDesc.type == ZE_IMAGE_TYPE_3D),
+      UR_RESULT_ERROR_INVALID_VALUE);
+#endif // !NDEBUG
+
+  uint32_t OriginX = ur_cast<uint32_t>(Origin->x);
+  uint32_t OriginY = ur_cast<uint32_t>(Origin->y);
+  uint32_t OriginZ = ur_cast<uint32_t>(Origin->z);
+
+  uint32_t Width = ur_cast<uint32_t>(Region->width);
+  uint32_t Height = ur_cast<uint32_t>(Region->height);
+  uint32_t Depth = ur_cast<uint32_t>(Region->depth);
+
+  ZeRegion = {OriginX, OriginY, OriginZ, Width, Height, Depth};
+
+  return UR_RESULT_SUCCESS;
+}
+
+std::pair<ze_image_format_type_t, size_t>
+getImageFormatTypeAndSize(const ur_image_format_t *ImageFormat) {
+  ze_image_format_type_t ZeImageFormatType;
+  size_t ZeImageFormatTypeSize; // per-channel size in bits
+  switch (ImageFormat->channelType) {
+  case UR_IMAGE_CHANNEL_TYPE_FLOAT: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_FLOAT;
+    ZeImageFormatTypeSize = 32;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_FLOAT;
+    ZeImageFormatTypeSize = 16;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UINT;
+    ZeImageFormatTypeSize = 32;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UINT;
+    ZeImageFormatTypeSize = 16;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UINT;
+    ZeImageFormatTypeSize = 8;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_UNORM_INT16: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UNORM;
+    ZeImageFormatTypeSize = 16;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_UNORM_INT8: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UNORM;
+    ZeImageFormatTypeSize = 8;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SINT;
+    ZeImageFormatTypeSize = 32;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SINT;
+    ZeImageFormatTypeSize = 16;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SINT;
+    ZeImageFormatTypeSize = 8;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_SNORM_INT16: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SNORM;
+    ZeImageFormatTypeSize = 16;
+    break;
+  }
+  case UR_IMAGE_CHANNEL_TYPE_SNORM_INT8: {
+    ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SNORM;
+    ZeImageFormatTypeSize = 8;
+    break;
+  }
+  default:
+    urPrint("urMemImageCreate: unsupported image data type: data type = %d\n",
+            ImageFormat->channelType);
+    ur::unreachable();
+  }
+  return {ZeImageFormatType, ZeImageFormatTypeSize};
+}
 
 UR_APIEXPORT ur_result_t UR_APICALL urUSMPitchedAllocExp(
     ur_context_handle_t hContext, ur_device_handle_t hDevice,
     const ur_usm_desc_t *pUSMDesc, ur_usm_pool_handle_t pool,
     size_t widthInBytes, size_t height, size_t elementSizeBytes, void **ppMem,
     size_t *pResultPitch) {
-  std::ignore = hContext;
-  std::ignore = hDevice;
-  std::ignore = pUSMDesc;
-  std::ignore = pool;
-  std::ignore = widthInBytes;
-  std::ignore = height;
-  std::ignore = elementSizeBytes;
-  std::ignore = ppMem;
-  std::ignore = pResultPitch;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  UR_ASSERT(hContext && hDevice, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
+  UR_ASSERT(widthInBytes != 0, UR_RESULT_ERROR_INVALID_USM_SIZE);
+  UR_ASSERT(elementSizeBytes != 0, UR_RESULT_ERROR_INVALID_USM_SIZE);
+  UR_ASSERT(ppMem && pResultPitch, UR_RESULT_ERROR_INVALID_NULL_POINTER);
+  std::shared_lock Lock(hContext->Mutex); // only after hContext is validated
+
+  static std::once_flag InitFlag;
+  std::call_once(InitFlag, [&]() {
+    ze_driver_handle_t DriverHandle = hContext->getPlatform()->ZeDriver;
+    auto Result = zeDriverGetExtensionFunctionAddress(
+        DriverHandle, "zeMemGetPitchFor2dImage",
+        (void **)&zeMemGetPitchFor2dImageFunctionPtr);
+    if (Result != ZE_RESULT_SUCCESS)
+      urPrint("zeDriverGetExtensionFunctionAddress zeMemGetPitchFor2dImage "
+              "failed, err = %d\n",
+              Result);
+  });
+  if (!zeMemGetPitchFor2dImageFunctionPtr)
+    return UR_RESULT_ERROR_INVALID_OPERATION;
+
+  size_t Width = widthInBytes / elementSizeBytes; // width in elements
+  size_t RowPitch;
+  ZE2UR_CALL(zeMemGetPitchFor2dImageFunctionPtr,
+             (hContext->ZeContext, hDevice->ZeDevice, Width, height,
+              elementSizeBytes, &RowPitch));
+  *pResultPitch = RowPitch;
+
+  size_t Size = height * RowPitch;
+  UR_CALL(urUSMDeviceAlloc(hContext, hDevice, pUSMDesc, pool, Size, ppMem));
+
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL
@@ -36,32 +592,47 @@ urBindlessImagesUnsampledImageHandleDestroyExp(ur_context_handle_t hContext,
   std::ignore = hContext;
   std::ignore = hDevice;
   std::ignore = hImage;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL
 urBindlessImagesSampledImageHandleDestroyExp(ur_context_handle_t hContext,
                                              ur_device_handle_t hDevice,
                                              ur_exp_image_handle_t hImage) {
-  std::ignore = hContext;
-  std::ignore = hDevice;
-  std::ignore = hImage;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  // Sampled image is a combination of unsampled image and sampler.
+  // Sampler is released in urSamplerRelease.
+  return urBindlessImagesUnsampledImageHandleDestroyExp(hContext, hDevice,
+                                                        hImage);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageAllocateExp(
     ur_context_handle_t hContext, ur_device_handle_t hDevice,
     const ur_image_format_t *pImageFormat, const ur_image_desc_t *pImageDesc,
     ur_exp_image_mem_handle_t *phImageMem) {
-  std::ignore = hContext;
-  std::ignore = hDevice;
-  std::ignore = pImageFormat;
-  std::ignore = pImageDesc;
-  std::ignore = phImageMem;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  UR_ASSERT(hContext && hDevice, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
+  UR_ASSERT(pImageFormat && pImageDesc && phImageMem,
+            UR_RESULT_ERROR_INVALID_NULL_POINTER);
+
+  std::shared_lock Lock(hContext->Mutex); // only after hContext is validated
+
+  ZeStruct<ze_image_desc_t> ZeImageDesc;
+  UR_CALL(ur2zeImageDesc(pImageFormat, pImageDesc, ZeImageDesc));
+
+  ze_image_bindless_exp_desc_t ZeImageBindlessDesc;
+  ZeImageBindlessDesc.stype = ZE_STRUCTURE_TYPE_BINDLESS_IMAGE_EXP_DESC;
+  ZeImageBindlessDesc.pNext = nullptr;
+  ZeImageBindlessDesc.flags = ZE_IMAGE_BINDLESS_EXP_FLAG_BINDLESS;
+  ZeImageDesc.pNext = &ZeImageBindlessDesc;
+
+  ze_image_handle_t ZeImage;
+  ZE2UR_CALL(zeImageCreate,
+             (hContext->ZeContext, hDevice->ZeDevice, &ZeImageDesc, &ZeImage));
+  ZE2UR_CALL(zeContextMakeImageResident,
+             (hContext->ZeContext, hDevice->ZeDevice, ZeImage));
+  UR_CALL(createUrMemFromZeImage(hContext, ZeImage, /*OwnZeMemHandle*/ true,
+                                 ZeImageDesc, phImageMem));
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageFreeExp(
@@ -69,9 +640,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageFreeExp(
     ur_exp_image_mem_handle_t hImageMem) {
   std::ignore = hContext;
   std::ignore = hDevice;
-  std::ignore = hImageMem;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  UR_CALL(urMemRelease(reinterpret_cast<ur_mem_handle_t>(hImageMem)));
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp(
@@ -79,15 +649,83 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp(
     ur_exp_image_mem_handle_t hImageMem, const ur_image_format_t *pImageFormat,
     const ur_image_desc_t *pImageDesc, ur_mem_handle_t *phMem,
     ur_exp_image_handle_t *phImage) {
-  std::ignore = hContext;
-  std::ignore = hDevice;
-  std::ignore = hImageMem;
-  std::ignore = pImageFormat;
-  std::ignore = pImageDesc;
-  std::ignore = phMem;
-  std::ignore = phImage;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  UR_ASSERT(hContext && hDevice && hImageMem,
+            UR_RESULT_ERROR_INVALID_NULL_HANDLE);
+  UR_ASSERT(pImageFormat && pImageDesc && phMem && phImage,
+            UR_RESULT_ERROR_INVALID_NULL_POINTER);
+
+  std::shared_lock Lock(hContext->Mutex); // only after hContext is validated
+
+  ZeStruct<ze_image_desc_t> ZeImageDesc;
+  UR_CALL(ur2zeImageDesc(pImageFormat, pImageDesc, ZeImageDesc));
+
+  ze_image_handle_t ZeImage;
+
+  ze_memory_allocation_properties_t MemAllocProperties{
+      ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES};
+  ZE2UR_CALL(zeMemGetAllocProperties,
+             (hContext->ZeContext, hImageMem, &MemAllocProperties, nullptr));
+  if (MemAllocProperties.type == ZE_MEMORY_TYPE_UNKNOWN) {
+    _ur_image *UrImage = reinterpret_cast<_ur_image *>(hImageMem);
+    if (!isSameImageDesc(&UrImage->ZeImageDesc, &ZeImageDesc)) {
+      ze_image_bindless_exp_desc_t ZeImageBindlessDesc;
+      ZeImageBindlessDesc.stype = ZE_STRUCTURE_TYPE_BINDLESS_IMAGE_EXP_DESC;
+      ZeImageBindlessDesc.pNext = nullptr;
+      ZeImageBindlessDesc.flags = ZE_IMAGE_BINDLESS_EXP_FLAG_BINDLESS;
+      ZeImageDesc.pNext = &ZeImageBindlessDesc;
+      ZE2UR_CALL(zeImageViewCreateExt,
+                 (hContext->ZeContext, hDevice->ZeDevice, &ZeImageDesc,
+                  UrImage->ZeImage, &ZeImage));
+      ZE2UR_CALL(zeContextMakeImageResident,
+                 (hContext->ZeContext, hDevice->ZeDevice, ZeImage));
+      UR_CALL(createUrMemFromZeImage(hContext, ZeImage, /*OwnZeMemHandle*/ true,
+                                     ZeImageDesc, phMem));
+    } else {
+      ZeImage = UrImage->ZeImage;
+      *phMem = nullptr;
+    }
+  } else if (MemAllocProperties.type == ZE_MEMORY_TYPE_DEVICE) {
+    ze_image_pitched_exp_desc_t PitchedDesc;
+    PitchedDesc.stype = ZE_STRUCTURE_TYPE_PITCHED_IMAGE_EXP_DESC;
+    PitchedDesc.pNext = nullptr;
+    PitchedDesc.ptr = hImageMem;
+
+    ze_image_bindless_exp_desc_t BindlessDesc;
+    BindlessDesc.stype = ZE_STRUCTURE_TYPE_BINDLESS_IMAGE_EXP_DESC;
+    BindlessDesc.pNext = &PitchedDesc;
+    BindlessDesc.flags = ZE_IMAGE_BINDLESS_EXP_FLAG_BINDLESS;
+
+    ZeImageDesc.pNext = &BindlessDesc;
+
+    ZE2UR_CALL(zeImageCreate, (hContext->ZeContext, hDevice->ZeDevice,
+                               &ZeImageDesc, &ZeImage));
+    ZE2UR_CALL(zeContextMakeImageResident,
+               (hContext->ZeContext, hDevice->ZeDevice, ZeImage));
+    UR_CALL(createUrMemFromZeImage(hContext, ZeImage, /*OwnZeMemHandle*/ true,
+                                   ZeImageDesc, phMem));
+  } else {
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
+
+  static std::once_flag InitFlag;
+  std::call_once(InitFlag, [&]() {
+    ze_driver_handle_t DriverHandle = hContext->getPlatform()->ZeDriver;
+    auto Result = zeDriverGetExtensionFunctionAddress(
+        DriverHandle, "zeImageGetDeviceOffsetExp",
+        (void **)&zeImageGetDeviceOffsetExpFunctionPtr);
+    if (Result != ZE_RESULT_SUCCESS)
+      urPrint("zeDriverGetExtensionFunctionAddress zeImageGetDeviceOffsetExp "
+              "failed, err = %d\n",
+              Result);
+  });
+  if (!zeImageGetDeviceOffsetExpFunctionPtr)
+    return UR_RESULT_ERROR_INVALID_OPERATION;
+
+  uint64_t DeviceOffset{};
+  ZE2UR_CALL(zeImageGetDeviceOffsetExpFunctionPtr, (ZeImage, &DeviceOffset));
+  *phImage = reinterpret_cast<ur_exp_image_handle_t>(DeviceOffset);
+
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
@@ -95,16 +733,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
     ur_exp_image_mem_handle_t hImageMem, const ur_image_format_t *pImageFormat,
     const ur_image_desc_t *pImageDesc, ur_sampler_handle_t hSampler,
     ur_mem_handle_t *phMem, ur_exp_image_handle_t *phImage) {
-  std::ignore = hContext;
-  std::ignore = hDevice;
-  std::ignore = hImageMem;
-  std::ignore = pImageFormat;
-  std::ignore = pImageDesc;
-  std::ignore = hSampler;
-  std::ignore = phMem;
-  std::ignore = phImage;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  UR_ASSERT(hSampler, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
+  UR_CALL(urBindlessImagesUnsampledImageCreateExp(
+      hContext, hDevice, hImageMem, pImageFormat, pImageDesc, phMem, phImage));
+
+  struct combined_sampled_image_handle {
+    uint64_t raw_image_handle;
+    uint64_t raw_sampler_handle;
+  };
+  combined_sampled_image_handle *sampledImageHandle = // NOTE(review): 16-byte write via phImage - confirm caller storage
+      reinterpret_cast<combined_sampled_image_handle *>(phImage);
+  sampledImageHandle->raw_image_handle = reinterpret_cast<uint64_t>(*phImage);
+  sampledImageHandle->raw_sampler_handle =
+      reinterpret_cast<uint64_t>(hSampler->ZeSampler);
+
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageCopyExp(
@@ -114,32 +757,165 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageCopyExp(
     ur_rect_offset_t dstOffset, ur_rect_region_t copyExtent,
     ur_rect_region_t hostExtent, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
-  std::ignore = hQueue;
-  std::ignore = pDst;
-  std::ignore = pSrc;
-  std::ignore = pImageFormat;
-  std::ignore = pImageDesc;
-  std::ignore = imageCopyFlags;
-  std::ignore = srcOffset;
-  std::ignore = dstOffset;
-  std::ignore = copyExtent;
-  std::ignore = hostExtent;
-  std::ignore = numEventsInWaitList;
-  std::ignore = phEventWaitList;
-  std::ignore = phEvent;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  UR_ASSERT(hQueue, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
+  UR_ASSERT(pDst && pSrc && pImageFormat && pImageDesc,
+            UR_RESULT_ERROR_INVALID_NULL_POINTER);
+  UR_ASSERT(!(UR_EXP_IMAGE_COPY_FLAGS_MASK & imageCopyFlags),
+            UR_RESULT_ERROR_INVALID_ENUMERATION);
+  UR_ASSERT(!(pImageDesc && UR_MEM_TYPE_IMAGE1D_BUFFER < pImageDesc->type),
+            UR_RESULT_ERROR_INVALID_IMAGE_FORMAT_DESCRIPTOR);
+  if (imageCopyFlags != UR_EXP_IMAGE_COPY_FLAG_HOST_TO_DEVICE &&
+      imageCopyFlags != UR_EXP_IMAGE_COPY_FLAG_DEVICE_TO_HOST) {
+    urPrint("urBindlessImagesImageCopyExp: unexpected imageCopyFlags\n");
+    return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; // before any event/list exists
+  }
+
+  std::scoped_lock Lock(hQueue->Mutex); // only after hQueue is validated
+
+  ZeStruct<ze_image_desc_t> ZeImageDesc;
+  UR_CALL(ur2zeImageDesc(pImageFormat, pImageDesc, ZeImageDesc));
+
+  bool UseCopyEngine = hQueue->useCopyEngine(/*PreferCopyEngine*/ true);
+
+  _ur_ze_event_list_t TmpWaitList;
+  UR_CALL(TmpWaitList.createAndRetainUrZeEventList(
+      numEventsInWaitList, phEventWaitList, hQueue, UseCopyEngine));
+
+  bool Blocking = false;
+  // We want to batch these commands to avoid extra submissions (costly)
+  bool OkToBatch = true;
+
+  // Get a new command list to be used on this call
+  ur_command_list_ptr_t CommandList{};
+  UR_CALL(hQueue->Context->getAvailableCommandList(hQueue, CommandList,
+                                                   UseCopyEngine, OkToBatch));
+
+  ze_event_handle_t ZeEvent = nullptr;
+  ur_event_handle_t InternalEvent;
+  bool IsInternal = phEvent == nullptr;
+  ur_event_handle_t *Event = phEvent ? phEvent : &InternalEvent;
+  UR_CALL(createEventAndAssociateQueue(hQueue, Event, UR_COMMAND_MEM_IMAGE_COPY,
+                                       CommandList, IsInternal,
+                                       /*IsMultiDevice*/ false));
+  ZeEvent = (*Event)->ZeEvent;
+  (*Event)->WaitList = TmpWaitList;
+
+  const auto &ZeCommandList = CommandList->first;
+  const auto &WaitList = (*Event)->WaitList;
+
+  if (imageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_HOST_TO_DEVICE) {
+    if (pImageDesc->rowPitch == 0) {
+      // Copy to Non-USM memory
+      ze_image_region_t DstRegion;
+      UR_CALL(getImageRegionHelper(ZeImageDesc, &dstOffset, &copyExtent,
+                                   DstRegion));
+      auto *UrImage = static_cast<_ur_image *>(pDst);
+      ZE2UR_CALL(zeCommandListAppendImageCopyFromMemory,
+                 (ZeCommandList, UrImage->ZeImage, pSrc, &DstRegion, ZeEvent,
+                  WaitList.Length, WaitList.ZeEventList));
+    } else {
+      // Copy to pitched USM memory
+      uint32_t DstPitch = pImageDesc->rowPitch;
+      ze_copy_region_t ZeDstRegion = {
+          (uint32_t)dstOffset.x,       (uint32_t)dstOffset.y,
+          (uint32_t)dstOffset.z,       DstPitch,
+          (uint32_t)copyExtent.height, (uint32_t)copyExtent.depth};
+      uint32_t DstSlicePitch = 0;
+      uint32_t SrcPitch = hostExtent.width * getPixelSizeBytes(pImageFormat);
+      ze_copy_region_t ZeSrcRegion = {
+          (uint32_t)srcOffset.x,       (uint32_t)srcOffset.y,
+          (uint32_t)srcOffset.z,       SrcPitch,
+          (uint32_t)copyExtent.height, (uint32_t)copyExtent.depth};
+      uint32_t SrcSlicePitch = 0;
+      ZE2UR_CALL(zeCommandListAppendMemoryCopyRegion,
+                 (ZeCommandList, pDst, &ZeDstRegion, DstPitch, DstSlicePitch,
+                  pSrc, &ZeSrcRegion, SrcPitch, SrcSlicePitch, ZeEvent,
+                  WaitList.Length, WaitList.ZeEventList));
+    }
+  } else if (imageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_DEVICE_TO_HOST) {
+    if (pImageDesc->rowPitch == 0) {
+      // Copy from Non-USM memory to host
+      ze_image_region_t SrcRegion;
+      UR_CALL(getImageRegionHelper(ZeImageDesc, &srcOffset, &copyExtent,
+                                   SrcRegion));
+      auto *UrImage = static_cast<_ur_image *>(pSrc);
+      ZE2UR_CALL(zeCommandListAppendImageCopyToMemory,
+                 (ZeCommandList, pDst, UrImage->ZeImage, &SrcRegion, ZeEvent,
+                  WaitList.Length, WaitList.ZeEventList));
+    } else {
+      // Copy from pitched USM memory to host
+      uint32_t DstPitch = copyExtent.width * getPixelSizeBytes(pImageFormat);
+      ze_copy_region_t ZeDstRegion = {
+          (uint32_t)dstOffset.x,       (uint32_t)dstOffset.y,
+          (uint32_t)dstOffset.z,       DstPitch,
+          (uint32_t)copyExtent.height, (uint32_t)copyExtent.depth};
+      uint32_t DstSlicePitch = 0;
+      uint32_t SrcPitch = pImageDesc->rowPitch;
+      ze_copy_region_t ZeSrcRegion = {
+          (uint32_t)srcOffset.x,       (uint32_t)srcOffset.y,
+          (uint32_t)srcOffset.z,       SrcPitch,
+          (uint32_t)copyExtent.height, (uint32_t)copyExtent.depth};
+      uint32_t SrcSlicePitch = 0;
+      ZE2UR_CALL(zeCommandListAppendMemoryCopyRegion,
+                 (ZeCommandList, pDst, &ZeDstRegion, DstPitch, DstSlicePitch,
+                  pSrc, &ZeSrcRegion, SrcPitch, SrcSlicePitch, ZeEvent,
+                  WaitList.Length, WaitList.ZeEventList));
+    }
+  }
+
+  UR_CALL(hQueue->executeCommandList(CommandList, Blocking, OkToBatch));
+
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageGetInfoExp(
     ur_exp_image_mem_handle_t hImageMem, ur_image_info_t propName,
     void *pPropValue, size_t *pPropSizeRet) {
-  std::ignore = hImageMem;
-  std::ignore = propName;
-  std::ignore = pPropValue;
-  std::ignore = pPropSizeRet;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  UR_ASSERT(hImageMem, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
+  UR_ASSERT(UR_IMAGE_INFO_DEPTH >= propName,
+            UR_RESULT_ERROR_INVALID_ENUMERATION);
+  UR_ASSERT(pPropValue || pPropSizeRet, UR_RESULT_ERROR_INVALID_NULL_POINTER);
+
+  auto *UrImage = reinterpret_cast<_ur_image *>(hImageMem);
+  ze_image_desc_t &Desc = UrImage->ZeImageDesc;
+  switch (propName) {
+  case UR_IMAGE_INFO_WIDTH:
+    if (pPropValue) {
+      *(uint64_t *)pPropValue = Desc.width;
+    }
+    if (pPropSizeRet) {
+      *pPropSizeRet = sizeof(uint64_t);
+    }
+    return UR_RESULT_SUCCESS;
+  case UR_IMAGE_INFO_HEIGHT:
+    if (pPropValue) {
+      *(uint32_t *)pPropValue = Desc.height;
+    }
+    if (pPropSizeRet) {
+      *pPropSizeRet = sizeof(uint32_t);
+    }
+    return UR_RESULT_SUCCESS;
+  case UR_IMAGE_INFO_DEPTH:
+    if (pPropValue) {
+      *(uint32_t *)pPropValue = Desc.depth;
+    }
+    if (pPropSizeRet) {
+      *pPropSizeRet = sizeof(uint32_t);
+    }
+    return UR_RESULT_SUCCESS;
+  case UR_IMAGE_INFO_FORMAT:
+    if (pPropValue) {
+      ur_image_format_t UrImageFormat;
+      UR_CALL(ze2urImageFormat(&Desc, &UrImageFormat));
+      *(ur_image_format_t *)pPropValue = UrImageFormat;
+    }
+    if (pPropSizeRet) {
+      *pPropSizeRet = sizeof(ur_image_format_t);
+    }
+    return UR_RESULT_SUCCESS;
+  default:
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMipmapGetLevelExp(
@@ -158,11 +934,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMipmapGetLevelExp(
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMipmapFreeExp(
     ur_context_handle_t hContext, ur_device_handle_t hDevice,
     ur_exp_image_mem_handle_t hMem) {
-  std::ignore = hContext;
-  std::ignore = hDevice;
-  std::ignore = hMem;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  return urBindlessImagesImageFreeExp(hContext, hDevice, hMem);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImportOpaqueFDExp(
diff --git a/source/adapters/level_zero/image.hpp b/source/adapters/level_zero/image.hpp
index d579a24708..618258601d 100644
--- a/source/adapters/level_zero/image.hpp
+++ b/source/adapters/level_zero/image.hpp
@@ -13,3 +13,11 @@
 #include
 #include
 #include
+
+ur_result_t getImageRegionHelper(ze_image_desc_t ZeImageDesc,
+                                 ur_rect_offset_t *Origin,
+                                 ur_rect_region_t *Region,
+                                 ze_image_region_t &ZeRegion);
+
+std::pair<ze_image_format_type_t, size_t>
+getImageFormatTypeAndSize(const ur_image_format_t *ImageFormat);
diff --git a/source/adapters/level_zero/memory.cpp b/source/adapters/level_zero/memory.cpp
index e977d1ac15..3fb0d79a61
100644 --- a/source/adapters/level_zero/memory.cpp +++ b/source/adapters/level_zero/memory.cpp @@ -14,6 +14,7 @@ #include "context.hpp" #include "event.hpp" +#include "image.hpp" #include "ur_level_zero.hpp" // Default to using compute engine for fill operation, but allow to @@ -287,49 +288,6 @@ static ur_result_t ZeHostMemAllocHelper(void **ResultPtr, return UR_RESULT_SUCCESS; } -static ur_result_t getImageRegionHelper(_ur_image *Mem, - ur_rect_offset_t *Origin, - ur_rect_region_t *Region, - ze_image_region_t &ZeRegion) { - UR_ASSERT(Mem, UR_RESULT_ERROR_INVALID_MEM_OBJECT); - UR_ASSERT(Origin, UR_RESULT_ERROR_INVALID_VALUE); - -#ifndef NDEBUG - auto UrImage = static_cast<_ur_image *>(Mem); - ze_image_desc_t &ZeImageDesc = UrImage->ZeImageDesc; - - UR_ASSERT(Mem->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT); - UR_ASSERT((ZeImageDesc.type == ZE_IMAGE_TYPE_1D && Origin->y == 0 && - Origin->z == 0) || - (ZeImageDesc.type == ZE_IMAGE_TYPE_1DARRAY && Origin->z == 0) || - (ZeImageDesc.type == ZE_IMAGE_TYPE_2D && Origin->z == 0) || - (ZeImageDesc.type == ZE_IMAGE_TYPE_3D), - UR_RESULT_ERROR_INVALID_VALUE); - - UR_ASSERT(Region->width && Region->height && Region->depth, - UR_RESULT_ERROR_INVALID_VALUE); - UR_ASSERT( - (ZeImageDesc.type == ZE_IMAGE_TYPE_1D && Region->height == 1 && - Region->depth == 1) || - (ZeImageDesc.type == ZE_IMAGE_TYPE_1DARRAY && Region->depth == 1) || - (ZeImageDesc.type == ZE_IMAGE_TYPE_2D && Region->depth == 1) || - (ZeImageDesc.type == ZE_IMAGE_TYPE_3D), - UR_RESULT_ERROR_INVALID_VALUE); -#endif // !NDEBUG - - uint32_t OriginX = ur_cast(Origin->x); - uint32_t OriginY = ur_cast(Origin->y); - uint32_t OriginZ = ur_cast(Origin->z); - - uint32_t Width = ur_cast(Region->width); - uint32_t Height = ur_cast(Region->height); - uint32_t Depth = ur_cast(Region->depth); - - ZeRegion = {OriginX, OriginY, OriginZ, Width, Height, Depth}; - - return UR_RESULT_SUCCESS; -} - // Helper function to implement image read/write/copy. 
// PI interfaces must have queue's and destination image's mutexes locked for // exclusive use and source image's mutex locked for shared use on entry. @@ -372,7 +330,8 @@ static ur_result_t enqueueMemImageCommandHelper( _ur_image *SrcMem = ur_cast<_ur_image *>(const_cast(Src)); ze_image_region_t ZeSrcRegion; - UR_CALL(getImageRegionHelper(SrcMem, SrcOrigin, Region, ZeSrcRegion)); + UR_CALL(getImageRegionHelper(SrcMem->ZeImageDesc, SrcOrigin, Region, + ZeSrcRegion)); // TODO: Level Zero does not support row_pitch/slice_pitch for images yet. // Check that SYCL RT did not want pitch larger than default. @@ -406,7 +365,8 @@ static ur_result_t enqueueMemImageCommandHelper( } else if (CommandType == UR_COMMAND_MEM_IMAGE_WRITE) { _ur_image *DstMem = ur_cast<_ur_image *>(Dst); ze_image_region_t ZeDstRegion; - UR_CALL(getImageRegionHelper(DstMem, DstOrigin, Region, ZeDstRegion)); + UR_CALL(getImageRegionHelper(DstMem->ZeImageDesc, DstOrigin, Region, + ZeDstRegion)); // TODO: Level Zero does not support row_pitch/slice_pitch for images yet. // Check that SYCL RT did not want pitch larger than default. 
@@ -440,9 +400,11 @@ static ur_result_t enqueueMemImageCommandHelper( _ur_image *DstImage = ur_cast<_ur_image *>(Dst); ze_image_region_t ZeSrcRegion; - UR_CALL(getImageRegionHelper(SrcImage, SrcOrigin, Region, ZeSrcRegion)); + UR_CALL(getImageRegionHelper(SrcImage->ZeImageDesc, SrcOrigin, Region, + ZeSrcRegion)); ze_image_region_t ZeDstRegion; - UR_CALL(getImageRegionHelper(DstImage, DstOrigin, Region, ZeDstRegion)); + UR_CALL(getImageRegionHelper(DstImage->ZeImageDesc, DstOrigin, Region, + ZeDstRegion)); char *ZeHandleSrc = nullptr; char *ZeHandleDst = nullptr; @@ -1456,74 +1418,8 @@ static ur_result_t ur2zeImageDesc(const ur_image_format_t *ImageFormat, const ur_image_desc_t *ImageDesc, ZeStruct &ZeImageDesc) { - ze_image_format_type_t ZeImageFormatType; - size_t ZeImageFormatTypeSize; - switch (ImageFormat->channelType) { - case UR_IMAGE_CHANNEL_TYPE_FLOAT: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_FLOAT; - ZeImageFormatTypeSize = 32; - break; - } - case UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_FLOAT; - ZeImageFormatTypeSize = 16; - break; - } - case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UINT; - ZeImageFormatTypeSize = 32; - break; - } - case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UINT; - ZeImageFormatTypeSize = 16; - break; - } - case UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UINT; - ZeImageFormatTypeSize = 8; - break; - } - case UR_IMAGE_CHANNEL_TYPE_UNORM_INT16: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UNORM; - ZeImageFormatTypeSize = 16; - break; - } - case UR_IMAGE_CHANNEL_TYPE_UNORM_INT8: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_UNORM; - ZeImageFormatTypeSize = 8; - break; - } - case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SINT; - ZeImageFormatTypeSize = 32; - break; - } - case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16: { - ZeImageFormatType = 
ZE_IMAGE_FORMAT_TYPE_SINT; - ZeImageFormatTypeSize = 16; - break; - } - case UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SINT; - ZeImageFormatTypeSize = 8; - break; - } - case UR_IMAGE_CHANNEL_TYPE_SNORM_INT16: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SNORM; - ZeImageFormatTypeSize = 16; - break; - } - case UR_IMAGE_CHANNEL_TYPE_SNORM_INT8: { - ZeImageFormatType = ZE_IMAGE_FORMAT_TYPE_SNORM; - ZeImageFormatTypeSize = 8; - break; - } - default: - urPrint("urMemImageCreate: unsupported image data type: data type = %d\n", - ImageFormat->channelType); - return UR_RESULT_ERROR_INVALID_VALUE; - } + auto [ZeImageFormatType, ZeImageFormatTypeSize] = + getImageFormatTypeAndSize(ImageFormat); // TODO: populate the layout mapping ze_image_format_layout_t ZeImageFormatLayout; @@ -1622,30 +1518,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate( ZE2UR_CALL(zeImageCreate, (Context->ZeContext, Device->ZeDevice, &ZeImageDesc, &ZeImage)); - try { - auto UrImage = new _ur_image(Context, ZeImage, true /*OwnZeMemHandle*/); - *Mem = reinterpret_cast(UrImage); - -#ifndef NDEBUG - UrImage->ZeImageDesc = ZeImageDesc; -#endif // !NDEBUG + UR_CALL(createUrMemFromZeImage(Context, ZeImage, /*OwnZeMemHandle*/ true, + ZeImageDesc, Mem)); - if ((Flags & UR_MEM_FLAG_USE_HOST_POINTER) != 0 || - (Flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) != 0) { - // Initialize image synchronously with immediate offload. - // zeCommandListAppendImageCopyFromMemory must not be called from - // simultaneous threads with the same command list handle, so we need - // exclusive lock. - std::scoped_lock Lock(Context->ImmediateCommandListMutex); - ZE2UR_CALL(zeCommandListAppendImageCopyFromMemory, - (Context->ZeCommandListInit, ZeImage, Host, nullptr, nullptr, - 0, nullptr)); - } - } catch (const std::bad_alloc &) { - return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } catch (...) 
{ - return UR_RESULT_ERROR_UNKNOWN; + if ((Flags & UR_MEM_FLAG_USE_HOST_POINTER) != 0 || + (Flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) != 0) { + // Initialize image synchronously with immediate offload. + // zeCommandListAppendImageCopyFromMemory must not be called from + // simultaneous threads with the same command list handle, so we need + // exclusive lock. + std::scoped_lock Lock(Context->ImmediateCommandListMutex); + ZE2UR_CALL(zeCommandListAppendImageCopyFromMemory, + (Context->ZeCommandListInit, ZeImage, Host, nullptr, nullptr, 0, + nullptr)); } + return UR_RESULT_SUCCESS; } @@ -1664,30 +1551,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreateWithNativeHandle( ze_image_handle_t ZeHImage = ur_cast(NativeMem); - _ur_image *Image = nullptr; - try { - Image = new _ur_image(Context, ZeHImage, Properties->isNativeHandleOwned); - *Mem = reinterpret_cast(Image); - + ZeStruct ZeImageDesc; #ifndef NDEBUG - ZeStruct ZeImageDesc; - ur_result_t Res = ur2zeImageDesc(ImageFormat, ImageDesc, ZeImageDesc); - if (Res != UR_RESULT_SUCCESS) { - delete Image; - *Mem = nullptr; - return Res; - } - Image->ZeImageDesc = ZeImageDesc; + ur_result_t Res = ur2zeImageDesc(ImageFormat, ImageDesc, ZeImageDesc); + if (Res != UR_RESULT_SUCCESS) { + *Mem = nullptr; + return Res; + } #else - std::ignore = ImageFormat; - std::ignore = ImageDesc; + std::ignore = ImageFormat; + std::ignore = ImageDesc; #endif // !NDEBUG - } catch (const std::bad_alloc &) { - return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } catch (...) 
{ - return UR_RESULT_ERROR_UNKNOWN; - } + UR_CALL(createUrMemFromZeImage( + Context, ZeHImage, Properties->isNativeHandleOwned, ZeImageDesc, Mem)); return UR_RESULT_SUCCESS; } diff --git a/source/adapters/level_zero/memory.hpp b/source/adapters/level_zero/memory.hpp index 8efd5b136e..1fb6d6fa31 100644 --- a/source/adapters/level_zero/memory.hpp +++ b/source/adapters/level_zero/memory.hpp @@ -206,11 +206,26 @@ struct _ur_image final : ur_mem_handle_t_ { bool isImage() const override { return true; } -#ifndef NDEBUG - // Keep the descriptor of the image (for debugging purposes) + // Keep the descriptor of the image ZeStruct ZeImageDesc; -#endif // !NDEBUG // Level Zero image handle. ze_image_handle_t ZeImage; }; + +template +ur_result_t +createUrMemFromZeImage(ur_context_handle_t Context, ze_image_handle_t ZeImage, + bool OwnZeMemHandle, + const ZeStruct &ZeImageDesc, T *UrMem) { + try { + auto UrImage = new _ur_image(Context, ZeImage, OwnZeMemHandle); + UrImage->ZeImageDesc = ZeImageDesc; + *UrMem = reinterpret_cast(UrImage); + } catch (const std::bad_alloc &) { + return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } catch (...) { + return UR_RESULT_ERROR_UNKNOWN; + } + return UR_RESULT_SUCCESS; +}