diff --git a/source/adapters/cuda/memory.cpp b/source/adapters/cuda/memory.cpp index f479522fb3..f097d2474e 100644 --- a/source/adapters/cuda/memory.cpp +++ b/source/adapters/cuda/memory.cpp @@ -398,11 +398,111 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate( return Result; } -/// \TODO Not implemented -UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t, - ur_image_info_t, size_t, - void *, size_t *) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t hMemory, + ur_image_info_t propName, + size_t propSize, + void *pPropValue, + size_t *pPropSizeRet) { + UR_ASSERT(hMemory->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT); + + auto Context = hMemory->getContext(); + + ScopedContext Active(Context); + UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); + + try { + CUDA_ARRAY3D_DESCRIPTOR ArrayInfo; + + UR_CHECK_ERROR(cuArray3DGetDescriptor( + &ArrayInfo, std::get(hMemory->Mem).getArray())); + + const auto cuda2urFormat = [](CUarray_format CUFormat, + ur_image_channel_type_t *ChannelType) { + switch (CUFormat) { + case CU_AD_FORMAT_UNSIGNED_INT8: + *ChannelType = UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8; + break; + case CU_AD_FORMAT_UNSIGNED_INT16: + *ChannelType = UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16; + break; + case CU_AD_FORMAT_UNSIGNED_INT32: + *ChannelType = UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32; + break; + case CU_AD_FORMAT_SIGNED_INT8: + *ChannelType = UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8; + break; + case CU_AD_FORMAT_SIGNED_INT16: + *ChannelType = UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16; + break; + case CU_AD_FORMAT_SIGNED_INT32: + *ChannelType = UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32; + break; + case CU_AD_FORMAT_HALF: + *ChannelType = UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT; + break; + case CU_AD_FORMAT_FLOAT: + *ChannelType = UR_IMAGE_CHANNEL_TYPE_FLOAT; + break; + default: + return UR_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT; + } + return UR_RESULT_SUCCESS; + }; + + const auto cudaFormatToElementSize = [](CUarray_format CUFormat, + size_t *Size) { + switch (CUFormat) { + case CU_AD_FORMAT_UNSIGNED_INT8: + case CU_AD_FORMAT_SIGNED_INT8: + *Size = 1; + break; + case CU_AD_FORMAT_UNSIGNED_INT16: + case CU_AD_FORMAT_SIGNED_INT16: + case CU_AD_FORMAT_HALF: + *Size = 2; + break; + case CU_AD_FORMAT_UNSIGNED_INT32: + case CU_AD_FORMAT_SIGNED_INT32: + case CU_AD_FORMAT_FLOAT: + *Size = 4; + break; + default: + return UR_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT; + } + return UR_RESULT_SUCCESS; + }; + + switch (propName) { + case UR_IMAGE_INFO_FORMAT: { + ur_image_channel_type_t ChannelType{}; + UR_CHECK_ERROR(cuda2urFormat(ArrayInfo.Format, &ChannelType)); + return ReturnValue( + ur_image_format_t{UR_IMAGE_CHANNEL_ORDER_RGBA, ChannelType}); + } + case UR_IMAGE_INFO_WIDTH: + return ReturnValue(ArrayInfo.Width); + case UR_IMAGE_INFO_HEIGHT: + return ReturnValue(ArrayInfo.Height); + case UR_IMAGE_INFO_DEPTH: + return ReturnValue(ArrayInfo.Depth); + case UR_IMAGE_INFO_ELEMENT_SIZE: { + size_t Size = 0; + UR_CHECK_ERROR(cudaFormatToElementSize(ArrayInfo.Format, &Size)); + return ReturnValue(Size); + } + case UR_IMAGE_INFO_ROW_PITCH: + case UR_IMAGE_INFO_SLICE_PITCH: + return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; + + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + } catch (ur_result_t Err) { + return Err; + } catch (...) { + return UR_RESULT_ERROR_UNKNOWN; + } } /// Implements a buffer partition in the CUDA backend. diff --git a/test/conformance/memory/memory_adapter_cuda.match b/test/conformance/memory/memory_adapter_cuda.match index 5fe265ae8e..b8c027fe4b 100644 --- a/test/conformance/memory/memory_adapter_cuda.match +++ b/test/conformance/memory/memory_adapter_cuda.match @@ -1,17 +1,5 @@ urMemBufferCreateWithNativeHandleTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}_ {{OPT}}urMemGetInfoImageTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_MEM_INFO_SIZE {{OPT}}urMemGetInfoImageTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_MEM_INFO_CONTEXT -{{OPT}}urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_FORMAT -{{OPT}}urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_ELEMENT_SIZE -{{OPT}}urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_ROW_PITCH -{{OPT}}urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_SLICE_PITCH -{{OPT}}urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_WIDTH -{{OPT}}urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_HEIGHT -{{OPT}}urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_DEPTH -{{OPT}}urMemImageGetInfoTest.InvalidSizeSmall/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_FORMAT -{{OPT}}urMemImageGetInfoTest.InvalidSizeSmall/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_ELEMENT_SIZE -{{OPT}}urMemImageGetInfoTest.InvalidSizeSmall/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_ROW_PITCH -{{OPT}}urMemImageGetInfoTest.InvalidSizeSmall/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_SLICE_PITCH -{{OPT}}urMemImageGetInfoTest.InvalidSizeSmall/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_WIDTH -{{OPT}}urMemImageGetInfoTest.InvalidSizeSmall/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_HEIGHT -{{OPT}}urMemImageGetInfoTest.InvalidSizeSmall/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_DEPTH +urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_ROW_PITCH +urMemImageGetInfoTest.Success/NVIDIA_CUDA_BACKEND___{{.*}}___UR_IMAGE_INFO_SLICE_PITCH