diff --git a/source/adapters/hip/device.hpp b/source/adapters/hip/device.hpp index e56e5e1293..269848f3cd 100644 --- a/source/adapters/hip/device.hpp +++ b/source/adapters/hip/device.hpp @@ -26,8 +26,13 @@ struct ur_device_handle_t_ { ur_platform_handle_t Platform; hipCtx_t HIPContext; uint32_t DeviceIndex; - int DeviceMaxLocalMem; - int ManagedMemSupport; + int MaxWorkGroupSize{0}; + int MaxBlockDimX{0}; + int MaxBlockDimY{0}; + int MaxBlockDimZ{0}; + int DeviceMaxLocalMem{0}; + int ManagedMemSupport{0}; + int ConcurrentManagedAccess{0}; public: ur_device_handle_t_(native_type HipDevice, hipCtx_t Context, @@ -35,11 +40,22 @@ struct ur_device_handle_t_ { : HIPDevice(HipDevice), RefCount{1}, Platform(Platform), HIPContext(Context), DeviceIndex(DeviceIndex) { + UR_CHECK_ERROR(hipDeviceGetAttribute( + &MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice)); + UR_CHECK_ERROR(hipDeviceGetAttribute( + &MaxBlockDimX, hipDeviceAttributeMaxBlockDimX, HIPDevice)); + UR_CHECK_ERROR(hipDeviceGetAttribute( + &MaxBlockDimY, hipDeviceAttributeMaxBlockDimY, HIPDevice)); + UR_CHECK_ERROR(hipDeviceGetAttribute( + &MaxBlockDimZ, hipDeviceAttributeMaxBlockDimZ, HIPDevice)); UR_CHECK_ERROR(hipDeviceGetAttribute( &DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock, HIPDevice)); UR_CHECK_ERROR(hipDeviceGetAttribute( &ManagedMemSupport, hipDeviceAttributeManagedMemory, HIPDevice)); + UR_CHECK_ERROR(hipDeviceGetAttribute( + &ConcurrentManagedAccess, hipDeviceAttributeConcurrentManagedAccess, + HIPDevice)); } ~ur_device_handle_t_() noexcept(false) { @@ -58,9 +74,21 @@ struct ur_device_handle_t_ { // platform uint32_t getIndex() const noexcept { return DeviceIndex; }; + int getMaxWorkGroupSize() const noexcept { return MaxWorkGroupSize; }; + + int getMaxBlockDimX() const noexcept { return MaxBlockDimX; }; + + int getMaxBlockDimY() const noexcept { return MaxBlockDimY; }; + + int getMaxBlockDimZ() const noexcept { return MaxBlockDimZ; }; + int getDeviceMaxLocalMem() const noexcept { return DeviceMaxLocalMem; }; int getManagedMemSupport() const noexcept { return ManagedMemSupport; }; + + int getConcurrentManagedAccess() const noexcept { + return ConcurrentManagedAccess; + }; }; int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute); diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index cd8b60e7a2..8219e2ff97 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -300,15 +300,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( bool ProvidedLocalWorkGroupSize = (pLocalWorkSize != nullptr); { - ur_result_t Result = urDeviceGetInfo( - hQueue->Device, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES, - sizeof(MaxThreadsPerBlock), MaxThreadsPerBlock, nullptr); - UR_ASSERT(Result == UR_RESULT_SUCCESS, Result); + MaxThreadsPerBlock[0] = hQueue->Device->getMaxBlockDimX(); + MaxThreadsPerBlock[1] = hQueue->Device->getMaxBlockDimY(); + MaxThreadsPerBlock[2] = hQueue->Device->getMaxBlockDimZ(); - Result = - urDeviceGetInfo(hQueue->Device, UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE, - sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr); - UR_ASSERT(Result == UR_RESULT_SUCCESS, Result); + MaxWorkGroupSize = hQueue->Device->getMaxWorkGroupSize(); // The MaxWorkGroupSize = 1024 for AMD GPU // The MaxThreadsPerBlock = {1024, 1024, 1024} @@ -1480,7 +1476,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( // If the device does not support managed memory access, we can't set // mem_advise. - if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) { + if (!Device->getManagedMemSupport()) { releaseEvent(); setErrorMessage("mem_advise ignored as device does not support " "managed memory access", @@ -1554,7 +1550,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, // If the device does not support managed memory access, we can't set // mem_advise. - if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) { + if (!Device->getManagedMemSupport()) { releaseEvent(); setErrorMessage("mem_advise ignored as device does not support " "managed memory access", @@ -1571,7 +1567,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE | UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE | UR_USM_ADVICE_FLAG_DEFAULT)) { - if (!getAttribute(Device, hipDeviceAttributeConcurrentManagedAccess)) { + if (!Device->getConcurrentManagedAccess()) { releaseEvent(); setErrorMessage("mem_advise ignored as device does not support " "concurrent managed access",