From af226b6630e66f060f1930212b94448b5fb6fc22 Mon Sep 17 00:00:00 2001 From: Dan Holmes Date: Tue, 12 Dec 2023 18:36:31 +0000 Subject: [PATCH] Fix-ups for first batch of unit tests --- source/loader/ur_lib.cpp | 102 ++++++++++-------- .../device/urDeviceGetSelected.cpp | 18 ++-- 2 files changed, 71 insertions(+), 49 deletions(-) diff --git a/source/loader/ur_lib.cpp b/source/loader/ur_lib.cpp index 79079ca85b..c3dfc11616 100644 --- a/source/loader/ur_lib.cpp +++ b/source/loader/ur_lib.cpp @@ -219,7 +219,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, uint32_t *pNumDevices) { if (!hPlatform) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; } // NumEntries is max number of devices wanted by the caller (max usable length of phDevices) if (NumEntries < 0) { @@ -230,9 +230,22 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, } // pNumDevices is the actual number of device handles added to phDevices by this function if (NumEntries == 0 && !pNumDevices) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; + return UR_RESULT_ERROR_INVALID_SIZE; } + switch (DeviceType) { + case UR_DEVICE_TYPE_ALL: + case UR_DEVICE_TYPE_GPU: + case UR_DEVICE_TYPE_DEFAULT: + case UR_DEVICE_TYPE_CPU: + case UR_DEVICE_TYPE_FPGA: + case UR_DEVICE_TYPE_MCA: + break; + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + //urPrint("Unknown device type"); + break; + } // plan: // 0. basic validation of argument values (see code above) // 1. conversion of argument values into useful data items @@ -267,42 +280,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, // however // discard "2,*" == "*,2" - ur_platform_backend_t platformBackend; - if (UR_RESULT_SUCCESS != - urPlatformGetInfo(hPlatform, UR_PLATFORM_INFO_BACKEND, - sizeof(ur_platform_backend_t), &platformBackend, 0)) { - return UR_RESULT_ERROR_INVALID_PLATFORM; - } - const std::string platformBackendName = // hPlatform->get_backend_name(); - [&platformBackend]() constexpr { - switch (platformBackend) { - case UR_PLATFORM_BACKEND_UNKNOWN: - return "*"; // the only ODS string that matches - break; - case UR_PLATFORM_BACKEND_LEVEL_ZERO: - return "level_zero"; - break; - case UR_PLATFORM_BACKEND_OPENCL: - return "opencl"; - break; - case UR_PLATFORM_BACKEND_CUDA: - return "cuda"; - break; - case UR_PLATFORM_BACKEND_HIP: - return "hip"; - break; - case UR_PLATFORM_BACKEND_NATIVE_CPU: - return "*"; // the only ODS string that matches - break; - case UR_PLATFORM_BACKEND_FORCE_UINT32: - return ""; // no ODS string matches this - break; - default: - return ""; // no ODS string matches this - break; - } - }(); - // The std::map is sorted by its key, so this method of parsing the ODS env var // alters the ordering of the terms, which makes it impossible to check whether // all discard terms appear after all accept terms and to preserve the ordering @@ -314,7 +291,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, // discard term, for that backend. // (If we wished to preserve the ordering of terms, we could replace `std::map` // with `std::queue>` or something similar.) - auto &maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", true); + auto maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", true); // if the ODS env var is not set at all, then pretend it was set to the default using EnvVarMap = std::map>; @@ -359,6 +336,42 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, "\\*\\.\\*\\.\\*" // C++ and regex escapes, literal '*.*.*' ")$", std::regex_constants::icase); + ur_platform_backend_t platformBackend; + if (UR_RESULT_SUCCESS != + urPlatformGetInfo(hPlatform, UR_PLATFORM_INFO_BACKEND, + sizeof(ur_platform_backend_t), &platformBackend, 0)) { + return UR_RESULT_ERROR_INVALID_PLATFORM; + } + const std::string platformBackendName = // hPlatform->get_backend_name(); + [&platformBackend]() constexpr { + switch (platformBackend) { + case UR_PLATFORM_BACKEND_UNKNOWN: + return "*"; // the only ODS string that matches + break; + case UR_PLATFORM_BACKEND_LEVEL_ZERO: + return "level_zero"; + break; + case UR_PLATFORM_BACKEND_OPENCL: + return "opencl"; + break; + case UR_PLATFORM_BACKEND_CUDA: + return "cuda"; + break; + case UR_PLATFORM_BACKEND_HIP: + return "hip"; + break; + case UR_PLATFORM_BACKEND_NATIVE_CPU: + return "*"; // the only ODS string that matches + break; + case UR_PLATFORM_BACKEND_FORCE_UINT32: + return ""; // no ODS string matches this + break; + default: + return ""; // no ODS string matches this + break; + } + }(); + using DeviceHardwareType = ur_device_type_t; enum class DevicePartLevel { @@ -600,7 +613,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, rootDevices.end()); } - // To support sub-device terms: std::for_each( rootDevices.cbegin(), rootDevices.cend(), @@ -780,11 +792,15 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, if (NumEntries == 0) { *pNumDevices = static_cast(selectedDevices.size()); } else if (NumEntries > 0) { - *pNumDevices = static_cast( - std::min((size_t)NumEntries, selectedDevices.size())); - std::copy_n(selectedDevices.cbegin(), *pNumDevices, phDevices); + size_t numToCopy = std::min((size_t)NumEntries, selectedDevices.size()); + std::copy_n(selectedDevices.cbegin(), numToCopy, phDevices); + if (pNumDevices != nullptr) { + *pNumDevices = static_cast(numToCopy); + return UR_RESULT_ERROR_ADAPTER_SPECIFIC; + } } + return UR_RESULT_SUCCESS; } } // namespace ur_lib diff --git a/test/conformance/device/urDeviceGetSelected.cpp b/test/conformance/device/urDeviceGetSelected.cpp index 0cb2738cf3..a5600e4eb7 100644 --- a/test/conformance/device/urDeviceGetSelected.cpp +++ b/test/conformance/device/urDeviceGetSelected.cpp @@ -9,13 +9,18 @@ using urDeviceGetSelectedTest = uur::urPlatformTest; TEST_F(urDeviceGetSelectedTest, Success) { uint32_t count = 0; - ASSERT_SUCCESS( - urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ur_result_t res1 = + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count); + ASSERT_EQ_RESULT(res1, UR_RESULT_SUCCESS); ASSERT_NE(count, 0); std::vector devices(count); - ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count, - devices.data(), nullptr)); - for (auto device : devices) { + ASSERT_NE(devices.size(), 0); + ASSERT_NE(devices.data(), nullptr); + //FAIL() << "devices.size() = " << devices.size() << " whereas count = " << count; + ur_result_t res2 = urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count, + devices.data(), nullptr); + ASSERT_EQ_RESULT(res2, UR_RESULT_SUCCESS); + for (auto &device : devices) { ASSERT_NE(nullptr, device); } } @@ -25,7 +30,8 @@ TEST_F(urDeviceGetSelectedTest, SuccessSubsetOfDevices) { ASSERT_SUCCESS( urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); if (count < 2) { - GTEST_SKIP(); + GTEST_SKIP() << "There are fewer than two devices in total for the " + "platform so the subset test is impossible"; } std::vector devices(count - 1); ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count - 1,