Skip to content

Commit

Permalink
Merge pull request #1217 from konradkusiak97/cachHIPcallsEnqueue
Browse files: browse the repository at this point in the history
[HIP] Cache some of the HIP driver calls from kernel enqueue
  • Loading branch information
kbenzie committed Feb 2, 2024
2 parents 76a2a9d + 6b96993 commit bd745d1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
43 changes: 42 additions & 1 deletion source/adapters/hip/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,37 @@ struct ur_device_handle_t_ {
ur_platform_handle_t Platform;
hipCtx_t HIPContext;
uint32_t DeviceIndex;
int MaxWorkGroupSize{0};
int MaxBlockDimX{0};
int MaxBlockDimY{0};
int MaxBlockDimZ{0};
int DeviceMaxLocalMem{0};
int ManagedMemSupport{0};
int ConcurrentManagedAccess{0};

public:
// Constructs the device handle and queries a set of HIP device attributes
// once, storing them in members so later callers can read them through the
// getters instead of calling hipDeviceGetAttribute again.
// Every attribute query is wrapped in UR_CHECK_ERROR.
// NOTE(review): the two consecutive `HIPContext(Context), ...` lines below
// are the removed and added sides of the rendered diff; only the second one
// (ending in `{`) belongs to the new constructor.
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
ur_platform_handle_t Platform, uint32_t DeviceIndex)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
HIPContext(Context), DeviceIndex(DeviceIndex) {}
HIPContext(Context), DeviceIndex(DeviceIndex) {

// Maximum threads per block, stored as MaxWorkGroupSize.
UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
// Per-dimension block limits (X, Y, Z).
UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxBlockDimX, hipDeviceAttributeMaxBlockDimX, HIPDevice));
UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxBlockDimY, hipDeviceAttributeMaxBlockDimY, HIPDevice));
UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxBlockDimZ, hipDeviceAttributeMaxBlockDimZ, HIPDevice));
// Maximum shared memory per block, stored as DeviceMaxLocalMem.
UR_CHECK_ERROR(hipDeviceGetAttribute(
&DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
HIPDevice));
// Managed-memory capability flags; callers test them for non-zero.
UR_CHECK_ERROR(hipDeviceGetAttribute(
&ManagedMemSupport, hipDeviceAttributeManagedMemory, HIPDevice));
UR_CHECK_ERROR(hipDeviceGetAttribute(
&ConcurrentManagedAccess, hipDeviceAttributeConcurrentManagedAccess,
HIPDevice));
}

~ur_device_handle_t_() noexcept(false) {
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
Expand All @@ -48,6 +73,22 @@ struct ur_device_handle_t_ {
// Index of this device among the devices that belong to the same platform.
uint32_t getIndex() const noexcept {
  return DeviceIndex;
}

// Cached value of hipDeviceAttributeMaxThreadsPerBlock, queried once in the
// constructor.
int getMaxWorkGroupSize() const noexcept {
  return MaxWorkGroupSize;
}

// Cached value of hipDeviceAttributeMaxBlockDimX, queried once in the
// constructor.
int getMaxBlockDimX() const noexcept {
  return MaxBlockDimX;
}

// Cached value of hipDeviceAttributeMaxBlockDimY, queried once in the
// constructor.
int getMaxBlockDimY() const noexcept {
  return MaxBlockDimY;
}

// Cached value of hipDeviceAttributeMaxBlockDimZ, queried once in the
// constructor.
int getMaxBlockDimZ() const noexcept {
  return MaxBlockDimZ;
}

// Cached value of hipDeviceAttributeMaxSharedMemoryPerBlock, queried once in
// the constructor.
int getDeviceMaxLocalMem() const noexcept {
  return DeviceMaxLocalMem;
}

// Cached value of hipDeviceAttributeManagedMemory, queried once in the
// constructor. Callers test the result for non-zero.
int getManagedMemSupport() const noexcept {
  return ManagedMemSupport;
}

// Cached value of hipDeviceAttributeConcurrentManagedAccess, queried once in
// the constructor. Callers test the result for non-zero.
int getConcurrentManagedAccess() const noexcept { return ConcurrentManagedAccess; }
};

// Free-function attribute lookup for a device handle; returns an int.
// NOTE(review): the definition is not visible here — presumably it queries
// `Attribute` from the HIP driver for `Device` on every call; confirm against
// the implementation.
int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);
24 changes: 8 additions & 16 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
bool ProvidedLocalWorkGroupSize = (pLocalWorkSize != nullptr);

{
ur_result_t Result = urDeviceGetInfo(
hQueue->Device, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES,
sizeof(MaxThreadsPerBlock), MaxThreadsPerBlock, nullptr);
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
MaxThreadsPerBlock[0] = hQueue->Device->getMaxBlockDimX();
MaxThreadsPerBlock[1] = hQueue->Device->getMaxBlockDimY();
MaxThreadsPerBlock[2] = hQueue->Device->getMaxBlockDimZ();

Result =
urDeviceGetInfo(hQueue->Device, UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr);
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
MaxWorkGroupSize = hQueue->Device->getMaxWorkGroupSize();

// The MaxWorkGroupSize = 1024 for AMD GPU
// The MaxThreadsPerBlock = {1024, 1024, 1024}
Expand Down Expand Up @@ -423,11 +419,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
: (LocalMemSzPtrPI ? LocalMemSzPtrPI : nullptr);

if (LocalMemSzPtr) {
int DeviceMaxLocalMem = 0;
UR_CHECK_ERROR(hipDeviceGetAttribute(
&DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
Dev->get()));

int DeviceMaxLocalMem = Dev->getDeviceMaxLocalMem();
static const int EnvVal = std::atoi(LocalMemSzPtr);
if (EnvVal <= 0 || EnvVal > DeviceMaxLocalMem) {
setErrorMessage(LocalMemSzPtrUR ? "Invalid value specified for "
Expand Down Expand Up @@ -1484,7 +1476,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(

// If the device does not support managed memory access, we can't set
// mem_advise.
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
if (!Device->getManagedMemSupport()) {
releaseEvent();
setErrorMessage("mem_advise ignored as device does not support "
"managed memory access",
Expand Down Expand Up @@ -1558,7 +1550,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,

// If the device does not support managed memory access, we can't set
// mem_advise.
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
if (!Device->getManagedMemSupport()) {
releaseEvent();
setErrorMessage("mem_advise ignored as device does not support "
"managed memory access",
Expand All @@ -1575,7 +1567,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE |
UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE |
UR_USM_ADVICE_FLAG_DEFAULT)) {
if (!getAttribute(Device, hipDeviceAttributeConcurrentManagedAccess)) {
if (!Device->getConcurrentManagedAccess()) {
releaseEvent();
setErrorMessage("mem_advise ignored as device does not support "
"concurrent managed access",
Expand Down

0 comments on commit bd745d1

Please sign in to comment.