Skip to content

Commit

Permalink
Merge pull request #1796 from GeorgeWeb/georgi/ur_kernel_max_active_wgs
Browse files Browse the repository at this point in the history
[CUDA] Implement urKernelSuggestMaxCooperativeGroupCountExp for Cuda
  • Loading branch information
omarahmed1111 authored Sep 10, 2024
2 parents e26bba5 + 45a781f commit eb63d1a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
7 changes: 1 addition & 6 deletions source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
return ReturnValue(4318u);
}
case UR_DEVICE_INFO_MAX_COMPUTE_UNITS: {
int ComputeUnits = 0;
UR_CHECK_ERROR(cuDeviceGetAttribute(
&ComputeUnits, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
hDevice->get()));
detail::ur::assertion(ComputeUnits >= 0);
return ReturnValue(static_cast<uint32_t>(ComputeUnits));
return ReturnValue(hDevice->getNumComputeUnits());
}
case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS: {
return ReturnValue(MaxWorkItemDimensions);
Expand Down
7 changes: 7 additions & 0 deletions source/adapters/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct ur_device_handle_t_ {
int MaxCapacityLocalMem{0};
int MaxChosenLocalMem{0};
bool MaxLocalMemSizeChosen{false};
uint32_t NumComputeUnits{0};

public:
ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase,
Expand All @@ -54,6 +55,10 @@ struct ur_device_handle_t_ {
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize,
nullptr));

UR_CHECK_ERROR(cuDeviceGetAttribute(
reinterpret_cast<int *>(&NumComputeUnits),
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, cuDevice));

// Set local mem max size if env var is present
static const char *LocalMemSizePtrUR =
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE");
Expand Down Expand Up @@ -107,6 +112,8 @@ struct ur_device_handle_t_ {
int getMaxChosenLocalMem() const noexcept { return MaxChosenLocalMem; };

bool maxLocalMemSizeChosen() { return MaxLocalMemSizeChosen; };

uint32_t getNumComputeUnits() const noexcept { return NumComputeUnits; };
};

int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute);
44 changes: 40 additions & 4 deletions source/adapters/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, size_t localWorkSize,
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
(void)hKernel;
(void)localWorkSize;
(void)dynamicSharedMemorySize;
*pGroupCountRet = 1;
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL);

// We need to set the active current device for this kernel explicitly here,
// because the occupancy querying API does not take device parameter.
ur_device_handle_t Device = hKernel->getProgram()->getDevice();
ScopedContext Active(Device);
try {
// We need to calculate max num of work-groups using per-device semantics.

int MaxNumActiveGroupsPerCU{0};
UR_CHECK_ERROR(cuOccupancyMaxActiveBlocksPerMultiprocessor(
&MaxNumActiveGroupsPerCU, hKernel->get(), localWorkSize,
dynamicSharedMemorySize));
detail::ur::assertion(MaxNumActiveGroupsPerCU >= 0);
// Handle the case where we can't have all SMs active with at least 1 group
// per SM. In that case, the device is still able to run 1 work-group, hence
// we will manually check if it is possible with the available HW resources.
if (MaxNumActiveGroupsPerCU == 0) {
size_t MaxWorkGroupSize{};
urKernelGetGroupInfo(
hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr);
size_t MaxLocalSizeBytes{};
urDeviceGetInfo(Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE,
sizeof(MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr);
if (localWorkSize > MaxWorkGroupSize ||
dynamicSharedMemorySize > MaxLocalSizeBytes ||
hasExceededMaxRegistersPerBlock(Device, hKernel, localWorkSize))
*pGroupCountRet = 0;
else
*pGroupCountRet = 1;
} else {
// Multiply by the number of SMs (CUs = compute units) on the device in
// order to retreive the total number of groups/blocks that can be
// launched.
*pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU;
}
} catch (ur_result_t Err) {
return Err;
}
return UR_RESULT_SUCCESS;
}

Expand Down

0 comments on commit eb63d1a

Please sign in to comment.