diff --git a/source/adapters/hip/common.hpp b/source/adapters/hip/common.hpp index 2649657f47..d7eea780a5 100644 --- a/source/adapters/hip/common.hpp +++ b/source/adapters/hip/common.hpp @@ -15,24 +15,39 @@ #include #include -// Hipify doesn't support cuArrayGetDescriptor, on AMD the hipArray can just be -// indexed, but on NVidia it is an opaque type and needs to go through -// cuArrayGetDescriptor so implement a utility function to get the array -// properties -inline void getArrayDesc(hipArray *Array, hipArray_Format &Format, - size_t &Channels) { +// Before ROCm 6, hipify doesn't support cuArrayGetDescriptor, on AMD the +// hipArray can just be indexed, but on NVidia it is an opaque type and needs to +// go through cuArrayGetDescriptor so implement a utility function to get the +// array properties +inline static hipError_t getArrayDesc(hipArray *Array, hipArray_Format &Format, + size_t &Channels) { +#if HIP_VERSION_MAJOR >= 6 + HIP_ARRAY_DESCRIPTOR ArrayDesc; + hipError_t err = hipArrayGetDescriptor(&ArrayDesc, Array); + if (err == hipSuccess) { + Format = ArrayDesc.Format; + Channels = ArrayDesc.NumChannels; + } + return err; +#else #if defined(__HIP_PLATFORM_AMD__) Format = Array->Format; Channels = Array->NumChannels; + return hipSuccess; #elif defined(__HIP_PLATFORM_NVIDIA__) CUDA_ARRAY_DESCRIPTOR ArrayDesc; - cuArrayGetDescriptor(&ArrayDesc, (CUarray)Array); - - Format = ArrayDesc.Format; - Channels = ArrayDesc.NumChannels; + CUresult err = cuArrayGetDescriptor(&ArrayDesc, (CUarray)Array); + if (err == CUDA_SUCCESS) { + Format = ArrayDesc.Format; + Channels = ArrayDesc.NumChannels; + return hipSuccess; + } else { + return hipErrorUnknown; // No easy way to map CUerror to hipError + } #else #error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); #endif +#endif } // HIP on NVIDIA headers guard hipArray3DCreate behind __CUDACC__, this does not diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index ebebcc27b5..109a248e16 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -898,7 +898,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead( hipArray_Format Format; size_t NumChannels; - getArrayDesc(Array, Format, NumChannels); + UR_CHECK_ERROR(getArrayDesc(Array, Format, NumChannels)); int ElementByteSize = imageElementByteSize(Format); @@ -959,7 +959,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite( hipArray_Format Format; size_t NumChannels; - getArrayDesc(Array, Format, NumChannels); + UR_CHECK_ERROR(getArrayDesc(Array, Format, NumChannels)); int ElementByteSize = imageElementByteSize(Format); @@ -1023,12 +1023,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy( hipArray *SrcArray = std::get(hImageSrc->Mem).getArray(); hipArray_Format SrcFormat; size_t SrcNumChannels; - getArrayDesc(SrcArray, SrcFormat, SrcNumChannels); + UR_CHECK_ERROR(getArrayDesc(SrcArray, SrcFormat, SrcNumChannels)); hipArray *DstArray = std::get(hImageDst->Mem).getArray(); hipArray_Format DstFormat; size_t DstNumChannels; - getArrayDesc(DstArray, DstFormat, DstNumChannels); + UR_CHECK_ERROR(getArrayDesc(DstArray, DstFormat, DstNumChannels)); UR_ASSERT(SrcFormat == DstFormat, UR_RESULT_ERROR_INVALID_IMAGE_FORMAT_DESCRIPTOR); diff --git a/source/adapters/hip/kernel.cpp b/source/adapters/hip/kernel.cpp index cc6f4384bc..bdd5f63fb2 100644 --- a/source/adapters/hip/kernel.cpp +++ b/source/adapters/hip/kernel.cpp @@ -279,7 +279,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgMemObj( auto array = std::get(hArgValue->Mem).getArray(); hipArray_Format Format; size_t NumChannels; - getArrayDesc(array, Format, NumChannels); + UR_CHECK_ERROR(getArrayDesc(array, Format, NumChannels)); if (Format != HIP_AD_FORMAT_UNSIGNED_INT32 && Format != HIP_AD_FORMAT_SIGNED_INT32 && Format != HIP_AD_FORMAT_HALF && Format != HIP_AD_FORMAT_FLOAT) { diff --git a/source/adapters/hip/usm.cpp b/source/adapters/hip/usm.cpp index 7af7401f87..8854748da9 100644 --- a/source/adapters/hip/usm.cpp +++ b/source/adapters/hip/usm.cpp @@ -73,7 +73,11 @@ UR_APIEXPORT ur_result_t UR_APICALL USMFreeImpl(ur_context_handle_t hContext, ScopedContext Active(hContext->getDevice()); hipPointerAttribute_t hipPointerAttributeType; UR_CHECK_ERROR(hipPointerGetAttributes(&hipPointerAttributeType, pMem)); - unsigned int Type = hipPointerAttributeType.memoryType; +#if HIP_VERSION >= 50600000 + const auto Type = hipPointerAttributeType.type; +#else + const auto Type = hipPointerAttributeType.memoryType; +#endif UR_ASSERT(Type == hipMemoryTypeDevice || Type == hipMemoryTypeHost, UR_RESULT_ERROR_INVALID_MEM_OBJECT); if (Type == hipMemoryTypeDevice) { @@ -171,7 +175,11 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, return ReturnValue(UR_USM_TYPE_SHARED); } UR_CHECK_ERROR(hipPointerGetAttributes(&hipPointerAttributeType, pMem)); +#if HIP_VERSION >= 50600000 + Value = hipPointerAttributeType.type; +#else Value = hipPointerAttributeType.memoryType; +#endif UR_ASSERT(Value == hipMemoryTypeDevice || Value == hipMemoryTypeHost, UR_RESULT_ERROR_INVALID_MEM_OBJECT); if (Value == hipMemoryTypeDevice) {