Skip to content

Commit

Permalink
Fix-ups for first batch of unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Wee-Free-Scot committed Dec 12, 2023
1 parent 22578d2 commit af226b6
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 49 deletions.
102 changes: 59 additions & 43 deletions source/loader/ur_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
uint32_t *pNumDevices) {

if (!hPlatform) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}
// NumEntries is max number of devices wanted by the caller (max usable length of phDevices)
if (NumEntries < 0) {
Expand All @@ -230,9 +230,22 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
}
// pNumDevices is the actual number of device handles added to phDevices by this function
if (NumEntries == 0 && !pNumDevices) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
return UR_RESULT_ERROR_INVALID_SIZE;
}

switch (DeviceType) {
case UR_DEVICE_TYPE_ALL:
case UR_DEVICE_TYPE_GPU:
case UR_DEVICE_TYPE_DEFAULT:
case UR_DEVICE_TYPE_CPU:
case UR_DEVICE_TYPE_FPGA:
case UR_DEVICE_TYPE_MCA:
break;
default:
return UR_RESULT_ERROR_INVALID_ENUMERATION;
//urPrint("Unknown device type");
break;
}
// plan:
// 0. basic validation of argument values (see code above)
// 1. conversion of argument values into useful data items
Expand Down Expand Up @@ -267,42 +280,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
// however
// discard "2,*" == "*,2"

ur_platform_backend_t platformBackend;
if (UR_RESULT_SUCCESS !=
urPlatformGetInfo(hPlatform, UR_PLATFORM_INFO_BACKEND,
sizeof(ur_platform_backend_t), &platformBackend, 0)) {
return UR_RESULT_ERROR_INVALID_PLATFORM;
}
const std::string platformBackendName = // hPlatform->get_backend_name();
[&platformBackend]() constexpr {
switch (platformBackend) {
case UR_PLATFORM_BACKEND_UNKNOWN:
return "*"; // the only ODS string that matches
break;
case UR_PLATFORM_BACKEND_LEVEL_ZERO:
return "level_zero";
break;
case UR_PLATFORM_BACKEND_OPENCL:
return "opencl";
break;
case UR_PLATFORM_BACKEND_CUDA:
return "cuda";
break;
case UR_PLATFORM_BACKEND_HIP:
return "hip";
break;
case UR_PLATFORM_BACKEND_NATIVE_CPU:
return "*"; // the only ODS string that matches
break;
case UR_PLATFORM_BACKEND_FORCE_UINT32:
return ""; // no ODS string matches this
break;
default:
return ""; // no ODS string matches this
break;
}
}();

// The std::map is sorted by its key, so this method of parsing the ODS env var
// alters the ordering of the terms, which makes it impossible to check whether
// all discard terms appear after all accept terms and to preserve the ordering
Expand All @@ -314,7 +291,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
// discard term, for that backend.
// (If we wished to preserve the ordering of terms, we could replace `std::map`
// with `std::queue<std::pair<key_type_t, value_type_t>>` or something similar.)
auto &maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", true);
auto maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", true);

// if the ODS env var is not set at all, then pretend it was set to the default
using EnvVarMap = std::map<std::string, std::vector<std::string>>;
Expand Down Expand Up @@ -359,6 +336,42 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
"\\*\\.\\*\\.\\*" // C++ and regex escapes, literal '*.*.*'
")$", std::regex_constants::icase);

ur_platform_backend_t platformBackend;
if (UR_RESULT_SUCCESS !=
urPlatformGetInfo(hPlatform, UR_PLATFORM_INFO_BACKEND,
sizeof(ur_platform_backend_t), &platformBackend, 0)) {
return UR_RESULT_ERROR_INVALID_PLATFORM;
}
const std::string platformBackendName = // hPlatform->get_backend_name();
[&platformBackend]() constexpr {
switch (platformBackend) {
case UR_PLATFORM_BACKEND_UNKNOWN:
return "*"; // the only ODS string that matches
break;
case UR_PLATFORM_BACKEND_LEVEL_ZERO:
return "level_zero";
break;
case UR_PLATFORM_BACKEND_OPENCL:
return "opencl";
break;
case UR_PLATFORM_BACKEND_CUDA:
return "cuda";
break;
case UR_PLATFORM_BACKEND_HIP:
return "hip";
break;
case UR_PLATFORM_BACKEND_NATIVE_CPU:
return "*"; // the only ODS string that matches
break;
case UR_PLATFORM_BACKEND_FORCE_UINT32:
return ""; // no ODS string matches this
break;
default:
return ""; // no ODS string matches this
break;
}
}();

using DeviceHardwareType = ur_device_type_t;

enum class DevicePartLevel {
Expand Down Expand Up @@ -600,7 +613,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
rootDevices.end());
}


// To support sub-device terms:
std::for_each(
rootDevices.cbegin(), rootDevices.cend(),
Expand Down Expand Up @@ -780,11 +792,15 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
if (NumEntries == 0) {
*pNumDevices = static_cast<uint32_t>(selectedDevices.size());
} else if (NumEntries > 0) {
*pNumDevices = static_cast<uint32_t>(
std::min((size_t)NumEntries, selectedDevices.size()));
std::copy_n(selectedDevices.cbegin(), *pNumDevices, phDevices);
size_t numToCopy = std::min((size_t)NumEntries, selectedDevices.size());
std::copy_n(selectedDevices.cbegin(), numToCopy, phDevices);
if (pNumDevices != nullptr) {
*pNumDevices = static_cast<uint32_t>(numToCopy);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}
}


return UR_RESULT_SUCCESS;
}
} // namespace ur_lib
18 changes: 12 additions & 6 deletions test/conformance/device/urDeviceGetSelected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@ using urDeviceGetSelectedTest = uur::urPlatformTest;

TEST_F(urDeviceGetSelectedTest, Success) {
uint32_t count = 0;
ASSERT_SUCCESS(
urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count));
ur_result_t res1 =
urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count);
ASSERT_EQ_RESULT(res1, UR_RESULT_SUCCESS);
ASSERT_NE(count, 0);
std::vector<ur_device_handle_t> devices(count);
ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count,
devices.data(), nullptr));
for (auto device : devices) {
ASSERT_NE(devices.size(), 0);
ASSERT_NE(devices.data(), nullptr);
//FAIL() << "devices.size() = " << devices.size() << " whereas count = " << count;
ur_result_t res2 = urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count,
devices.data(), nullptr);
ASSERT_EQ_RESULT(res2, UR_RESULT_SUCCESS);
for (auto &device : devices) {
ASSERT_NE(nullptr, device);
}
}
Expand All @@ -25,7 +30,8 @@ TEST_F(urDeviceGetSelectedTest, SuccessSubsetOfDevices) {
ASSERT_SUCCESS(
urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count));
if (count < 2) {
GTEST_SKIP();
GTEST_SKIP() << "There are fewer than two devices in total for the "
"platform so the subset test is impossible";
}
std::vector<ur_device_handle_t> devices(count - 1);
ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count - 1,
Expand Down

0 comments on commit af226b6

Please sign in to comment.