diff --git a/include/ur_api.h b/include/ur_api.h index 8579ff0326..89094bc9ff 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -220,6 +220,7 @@ typedef enum ur_function_t { UR_FUNCTION_COMMAND_BUFFER_RELEASE_COMMAND_EXP = 217, ///< Enumerator for ::urCommandBufferReleaseCommandExp UR_FUNCTION_COMMAND_BUFFER_GET_INFO_EXP = 218, ///< Enumerator for ::urCommandBufferGetInfoExp UR_FUNCTION_COMMAND_BUFFER_COMMAND_GET_INFO_EXP = 219, ///< Enumerator for ::urCommandBufferCommandGetInfoExp + UR_FUNCTION_DEVICE_GET_SELECTED = 220, ///< Enumerator for ::urDeviceGetSelected /// @cond UR_FUNCTION_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -1387,6 +1388,46 @@ urDeviceGet( ///< pNumDevices will be updated with the total number of devices available. ); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR +/// +/// @details +/// - Multiple calls to this function will return identical device handles, +/// in the same order. +/// - The number and order of handles returned from this function will be +/// affected by environment variables that filter or select which devices +/// are exposed through this API. +/// - A reference is taken for each returned device and must be released +/// with a subsequent call to ::urDeviceRelease. +/// - The application may call this function from simultaneous threads, the +/// implementation must be thread-safe. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hPlatform` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_DEVICE_TYPE_VPU < DeviceType` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +UR_APIEXPORT ur_result_t UR_APICALL +urDeviceGetSelected( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance + ur_device_type_t DeviceType, ///< [in] the type of the devices. + uint32_t NumEntries, ///< [in] the number of devices to be added to phDevices. + ///< If phDevices in not NULL then NumEntries should be greater than zero, + ///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE, + ///< will be returned. + ur_device_handle_t *phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices. + ///< If NumEntries is less than the number of devices available, then only + ///< that number of devices will be retrieved. + uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices. + ///< pNumDevices will be updated with the total number of selected devices + ///< available for the given platform. +); + /////////////////////////////////////////////////////////////////////////////// /// @brief Supported device info typedef enum ur_device_info_t { @@ -11148,6 +11189,18 @@ typedef struct ur_device_get_params_t { uint32_t **ppNumDevices; } ur_device_get_params_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for urDeviceGetSelected +/// @details Each entry is a pointer to the parameter passed to the function; +/// allowing the callback the ability to modify the parameter's value +typedef struct ur_device_get_selected_params_t { + ur_platform_handle_t *phPlatform; + ur_device_type_t *pDeviceType; + uint32_t *pNumEntries; + ur_device_handle_t **pphDevices; + uint32_t **ppNumDevices; +} ur_device_get_selected_params_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for urDeviceGetInfo /// @details Each entry is a pointer to the parameter passed to the function; diff --git a/include/ur_print.h b/include/ur_print.h index e1718e99f8..b4675aee02 100644 --- a/include/ur_print.h +++ b/include/ur_print.h @@ -2450,6 +2450,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urPrintVirtualMemGetInfoParams(const struct /// - `buff_size < out_size` UR_APIEXPORT ur_result_t UR_APICALL urPrintDeviceGetParams(const struct ur_device_get_params_t *params, char *buffer, const size_t buff_size, size_t *out_size); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print ur_device_get_selected_params_t struct +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// - `buff_size < out_size` +UR_APIEXPORT ur_result_t UR_APICALL urPrintDeviceGetSelectedParams(const struct ur_device_get_selected_params_t *params, char *buffer, const size_t buff_size, size_t *out_size); + /////////////////////////////////////////////////////////////////////////////// /// @brief Print ur_device_get_info_params_t struct /// @returns diff --git a/include/ur_print.hpp b/include/ur_print.hpp index b4c777b77d..3e665176d9 100644 --- a/include/ur_print.hpp +++ b/include/ur_print.hpp @@ -909,6 +909,9 @@ inline std::ostream &operator<<(std::ostream &os, ur_function_t value) { case UR_FUNCTION_COMMAND_BUFFER_COMMAND_GET_INFO_EXP: os << "UR_FUNCTION_COMMAND_BUFFER_COMMAND_GET_INFO_EXP"; break; + case UR_FUNCTION_DEVICE_GET_SELECTED: + os << "UR_FUNCTION_DEVICE_GET_SELECTED"; + break; default: os << "unknown enumerator"; break; @@ -16282,6 +16285,48 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct return os; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ur_device_get_selected_params_t type +/// @returns +/// std::ostream & +inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_device_get_selected_params_t *params) { + + os << ".hPlatform = "; + + ur::details::printPtr(os, + *(params->phPlatform)); + + os << ", "; + os << ".DeviceType = "; + + os << *(params->pDeviceType); + + os << ", "; + os << ".NumEntries = "; + + os << *(params->pNumEntries); + + os << ", "; + os << ".phDevices = {"; + for (size_t i = 0; *(params->pphDevices) != NULL && i < *params->pNumEntries; ++i) { + if (i != 0) { + os << ", "; + } + + ur::details::printPtr(os, + (*(params->pphDevices))[i]); + } + os << "}"; + + os << ", "; + os << ".pNumDevices = "; + + ur::details::printPtr(os, + *(params->ppNumDevices)); + + return os; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ur_device_get_info_params_t type /// @returns @@ -17080,6 +17125,9 @@ inline ur_result_t UR_APICALL printFunctionParams(std::ostream &os, ur_function_ case UR_FUNCTION_DEVICE_GET: { os << (const struct ur_device_get_params_t *)params; } break; + case UR_FUNCTION_DEVICE_GET_SELECTED: { + os << (const struct ur_device_get_selected_params_t *)params; + } break; case UR_FUNCTION_DEVICE_GET_INFO: { os << (const struct ur_device_get_info_params_t *)params; } break; diff --git a/scripts/core/device.yml b/scripts/core/device.yml index 61b004a4d0..8acdfafed5 100644 --- a/scripts/core/device.yml +++ b/scripts/core/device.yml @@ -150,6 +150,45 @@ returns: - "`NumEntries > 0 && phDevices == NULL`" - $X_RESULT_ERROR_INVALID_VALUE --- #-------------------------------------------------------------------------- +type: function +desc: "Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR" +class: $xDevice +loader_only: True +name: GetSelected +decl: static +ordinal: "0" +details: + - "Multiple calls to this function will return identical device handles, in the same order." + - "The number and order of handles returned from this function will be affected by environment variables that filter or select which devices are exposed through this API." + - "A reference is taken for each returned device and must be released with a subsequent call to $xDeviceRelease." + - "The application may call this function from simultaneous threads, the implementation must be thread-safe." +params: + - type: $x_platform_handle_t + name: hPlatform + desc: "[in] handle of the platform instance" + - type: "$x_device_type_t" + name: DeviceType + desc: | + [in] the type of the devices. + - type: "uint32_t" + name: NumEntries + desc: | + [in] the number of devices to be added to phDevices. + If phDevices in not NULL then NumEntries should be greater than zero, otherwise $X_RESULT_ERROR_INVALID_VALUE, + will be returned. + - type: "$x_device_handle_t*" + name: phDevices + desc: | + [out][optional][range(0, NumEntries)] array of handle of devices. + If NumEntries is less than the number of devices available, then only that number of devices will be retrieved. + - type: "uint32_t*" + name: pNumDevices + desc: | + [out][optional] pointer to the number of devices. + pNumDevices will be updated with the total number of selected devices available for the given platform. +returns: + - $X_RESULT_ERROR_INVALID_VALUE +--- #-------------------------------------------------------------------------- type: enum desc: "Supported device info" class: $xDevice diff --git a/scripts/core/registry.yml b/scripts/core/registry.yml index 363531580f..3e86e109c3 100644 --- a/scripts/core/registry.yml +++ b/scripts/core/registry.yml @@ -574,6 +574,9 @@ etors: - name: COMMAND_BUFFER_COMMAND_GET_INFO_EXP desc: Enumerator for $xCommandBufferCommandGetInfoExp value: '219' +- name: DEVICE_GET_SELECTED + desc: Enumerator for $xDeviceGetSelected + value: '220' --- type: enum desc: Defines structure types diff --git a/source/loader/ur_lib.cpp b/source/loader/ur_lib.cpp index 34531ca8b1..0e350ef3fa 100644 --- a/source/loader/ur_lib.cpp +++ b/source/loader/ur_lib.cpp @@ -1,6 +1,6 @@ /* * - * Copyright (C) 2022-2023 Intel Corporation + * Copyright (C) 2024 Intel Corporation * * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. * See LICENSE.TXT @@ -9,11 +9,20 @@ * @file ur_lib.cpp * */ + +// avoids windows.h from defining macros for min and max +// which avoids playing havoc with std::min and std::max +// (not quite sure why windows.h is being included here) +#ifndef NOMINMAX +#define NOMINMAX +#endif // !NOMINMAX + #include "ur_lib.hpp" #include "logger/ur_logger.hpp" #include "ur_loader.hpp" -#include +#include // for std::memcpy +#include namespace ur_lib { /////////////////////////////////////////////////////////////////////////////// @@ -206,4 +215,618 @@ urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig, return UR_RESULT_SUCCESS; } +ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, + ur_device_type_t DeviceType, + uint32_t NumEntries, + ur_device_handle_t *phDevices, + uint32_t *pNumDevices) { + + if (!hPlatform) { + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + } + // NumEntries is max number of devices wanted by the caller (max usable length of phDevices) + if (NumEntries < 0) { + return UR_RESULT_ERROR_INVALID_SIZE; + } + if (NumEntries > 0 && !phDevices) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + // pNumDevices is the actual number of device handles added to phDevices by this function + if (NumEntries == 0 && !pNumDevices) { + return UR_RESULT_ERROR_INVALID_SIZE; + } + + switch (DeviceType) { + case UR_DEVICE_TYPE_ALL: + case UR_DEVICE_TYPE_GPU: + case UR_DEVICE_TYPE_DEFAULT: + case UR_DEVICE_TYPE_CPU: + case UR_DEVICE_TYPE_FPGA: + case UR_DEVICE_TYPE_MCA: + break; + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + //urPrint("Unknown device type"); + break; + } + // plan: + // 0. basic validation of argument values (see code above) + // 1. conversion of argument values into useful data items + // 2. retrieval and parsing of environment variable string + // 3. conversion of term map to accept and discard filters + // 4. inserting a default "*:*" accept filter, if required + // 5. symbolic consolidation of accept and discard filters + // 6. querying the platform handles for all 'root' devices + // 7. partioning via platform root devices into subdevices + // 8. partioning via platform subdevices into subsubdevices + // 9. short-listing devices to accept using accept filters + // A. de-listing devices to discard using discard filters + + // possible symbolic short-circuit special cases exist: + // * if there are no terms, select all root devices + // * if any discard is "*", select no root devices + // * if any discard is "*.*", select no sub-devices + // * if any discard is "*.*.*", select no sub-sub-devices + // * + // + // detail for step 5 of above plan: + // * combine all accept filters into a single accept list + // * combine all discard filters into single discard list + // then invert it to make the initial/default accept list + // (needs knowledge of the valid range from the platform) + // "!level_zero:1,2" -> "level_zero:0,3,...,max" + // * finally subtract the discard set from the accept set + + // accept "2,*" != "*,2" + // because "2,*" == "2,0,1,3" + // whereas "*,2" == "0,1,2,3" + // however + // discard "2,*" == "*,2" + + // The std::map is sorted by its key, so this method of parsing the ODS env var + // alters the ordering of the terms, which makes it impossible to check whether + // all discard terms appear after all accept terms and to preserve the ordering + // of backends as specified in the ODS string. + // However, for single-platform requests, we are only interested in exactly one + // backend, and we know that discard filter terms always override accept filter + // terms, so the ordering of terms can be safely ignored -- in the special case + // where the whole ODS string contains at most one accept term, and at most one + // discard term, for that backend. + // (If we wished to preserve the ordering of terms, we could replace `std::map` + // with `std::queue>` or something similar.) + auto maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", false); + logger::debug( + "getenv_to_map parsed env var and {} a map", + (maybeEnvVarMap.has_value() ? "produced" : "failed to produce")); + + // if the ODS env var is not set at all, then pretend it was set to the default + using EnvVarMap = std::map>; + EnvVarMap mapODS = maybeEnvVarMap.has_value() ? maybeEnvVarMap.value() + : EnvVarMap{{"*", {"*"}}}; + + // the full BNF grammar can be found here: + // https://github.com/intel/llvm/blob/sycl/sycl/doc/EnvironmentVariables.md#oneapi_device_selector + + // discardFilter = "!acceptFilter" + // acceptFilter = "backend:filterStrings" + // filterStrings = "filterString[,filterString[,...]]" + // filterString = "root[.sub[.subsub]]" + // root = "*|int|cpu|gpu|fpga" + // sub = "*|int" + // subsub = "*|int" + + // validation regex for filterString (not used in this code) + std::regex validation_pattern( + "^(" + "\\*" // C++ escape for \, regex escape for literal '*' + "|" + "cpu" // ensure case-insenitive, when using + "|" + "gpu" // ensure case-insenitive, when using + "|" + "fpga" // ensure case-insenitive, when using + "|" + "[[:digit:]]+" // '' + "|" + "[[:digit:]]+\\.[[:digit:]]+" // '.' + "|" + "[[:digit:]]+\\.\\*" // '.*.*' + "|" + "\\*\\.\\*" // C++ and regex escapes, literal '*.*' + "|" + "[[:digit:]]+\\.[[:digit:]]+\\.[[:digit:]]+" // '..' + "|" + "[[:digit:]]+\\.[[:digit:]]+\\.\\*" // '..*' + "|" + "[[:digit:]]+\\.\\*\\.\\*" // '.*.*' + "|" + "\\*\\.\\*\\.\\*" // C++ and regex escapes, literal '*.*.*' + ")$", + std::regex_constants::icase); + + ur_platform_backend_t platformBackend; + if (UR_RESULT_SUCCESS != + urPlatformGetInfo(hPlatform, UR_PLATFORM_INFO_BACKEND, + sizeof(ur_platform_backend_t), &platformBackend, 0)) { + return UR_RESULT_ERROR_INVALID_PLATFORM; + } + const std::string platformBackendName = // hPlatform->get_backend_name(); + [&platformBackend]() constexpr { + switch (platformBackend) { + case UR_PLATFORM_BACKEND_UNKNOWN: + return "*"; // the only ODS string that matches + break; + case UR_PLATFORM_BACKEND_LEVEL_ZERO: + return "level_zero"; + break; + case UR_PLATFORM_BACKEND_OPENCL: + return "opencl"; + break; + case UR_PLATFORM_BACKEND_CUDA: + return "cuda"; + break; + case UR_PLATFORM_BACKEND_HIP: + return "hip"; + break; + case UR_PLATFORM_BACKEND_NATIVE_CPU: + return "*"; // the only ODS string that matches + break; + case UR_PLATFORM_BACKEND_FORCE_UINT32: + return ""; // no ODS string matches this + break; + default: + return ""; // no ODS string matches this + break; + } + }(); + + using DeviceHardwareType = ur_device_type_t; + + enum class DevicePartLevel { ROOT, SUB, SUBSUB }; + + using DeviceIdType = unsigned long; + constexpr DeviceIdType DeviceIdTypeALL = + -1; // ULONG_MAX but without #include + + struct DeviceSpec { + DevicePartLevel level; + DeviceHardwareType hwType = ::UR_DEVICE_TYPE_ALL; + DeviceIdType rootId = DeviceIdTypeALL; + DeviceIdType subId = DeviceIdTypeALL; + DeviceIdType subsubId = DeviceIdTypeALL; + ur_device_handle_t urDeviceHandle; + }; + + auto getRootHardwareType = + [](const std::string &input) -> DeviceHardwareType { + std::string lowerInput(input); + std::transform(lowerInput.cbegin(), lowerInput.cend(), + lowerInput.begin(), ::tolower); + if (lowerInput == "cpu") { + return ::UR_DEVICE_TYPE_CPU; + } + if (lowerInput == "gpu") { + return ::UR_DEVICE_TYPE_GPU; + } + if (lowerInput == "fpga") { + return ::UR_DEVICE_TYPE_FPGA; + } + return ::UR_DEVICE_TYPE_ALL; + }; + + auto getDeviceId = [&](const std::string &input) -> DeviceIdType { + if (input.find_first_not_of("0123456789") == std::string::npos) { + return std::stoul(input); + } + return DeviceIdTypeALL; + }; + + std::vector acceptDeviceList; + std::vector discardDeviceList; + + for (auto &termPair : mapODS) { + std::string backend = termPair.first; + if (backend + .empty()) { // FIXME: never true because getenv_to_map rejects this case + // malformed term: missing backend -- output ERROR, then continue + logger::error("ERROR: missing backend, format of filter = " + "'[!]backend:filterStrings'"); + continue; + } + enum FilterType { + AcceptFilter, + DiscardFilter, + } termType = (backend.front() != '!') ? AcceptFilter : DiscardFilter; + logger::debug( + "termType is {}", + (termType != AcceptFilter ? "DiscardFilter" : "AcceptFilter")); + auto &deviceList = + (termType != AcceptFilter) ? discardDeviceList : acceptDeviceList; + if (termType != AcceptFilter) { + logger::debug("DEBUG: backend was '{}'", backend); + backend.erase(backend.cbegin()); + logger::debug("DEBUG: backend now '{}'", backend); + } + // Note the hPlatform -> platformBackend -> platformBackendName conversion above + // guarantees minimal sanity for the comparison with backend from the ODS string + if (backend.front() != '*' && + !std::equal(platformBackendName.cbegin(), + platformBackendName.cend(), backend.cbegin(), + backend.cend(), [](const auto &a, const auto &b) { + // case-insensitive comparison by converting both tolower + return std::tolower( + static_cast(a)) == + std::tolower(static_cast(b)); + })) { + // irrelevant term for current request: different backend -- silently ignore + logger::warning( + "WARNING: ignoring term with irrelevant backend '{}'", backend); + continue; + } + if (termPair.second.size() == 0) { + // malformed term: missing filterStrings -- output ERROR, then continue + logger::error("ERROR missing filterStrings, format of filter = " + "'[!]backend:filterStrings'"); + continue; + } + if (std::find_if(termPair.second.cbegin(), termPair.second.cend(), + [](const auto &s) { return s.empty(); }) != + termPair.second + .cend()) { // FIXME: never true because getenv_to_map rejects this case + // malformed term: missing filterString -- output warning, then continue + logger::warning( + "WARNING: empty filterString, format of filterStrings " + "= 'filterString[,filterString[,...]]'"); + continue; + } + if (std::find_if(termPair.second.cbegin(), termPair.second.cend(), + [](const auto &s) { + return std::count(s.cbegin(), s.cend(), '.') > 2; + }) != termPair.second.cend()) { + // malformed term: too many dots in filterString -- output warning, then continue + logger::warning("WARNING: too many dots in filterString, format of " + "filterString = 'root[.sub[.subsub]]'"); + continue; + } + if (std::find_if( + termPair.second.cbegin(), termPair.second.cend(), + [](const auto &s) { + // GOOD: "*.*", "1.*.*", "*.*.*" + // BAD: "*.1", "*.", "1.*.2", "*.gpu" + std::string prefix = "*."; // every "*." pattern ... + std::string whole = "*.*"; // ... must be start of "*.*" + std::string::size_type pos = 0; + while ((pos = s.find(prefix, pos)) != std::string::npos) { + if (s.substr(pos, whole.size()) != whole) { + return true; // found a BAD thing, either "\*\.$" or "\*\.[^*]" + } + pos += prefix.size(); + } + return false; // no BAD things, so must be okay + }) != termPair.second.cend()) { + // malformed term: star dot no-star in filterString -- output warning, then continue + logger::warning( + "WARNING: invalid wildcard in filterString, '*.' => '*.*'"); + continue; + } + + // TODO -- use regex validation_pattern to catch all other syntax errors in the ODS string + + for (auto &filterString : termPair.second) { + std::string::size_type locationDot1 = filterString.find('.'); + if (locationDot1 != std::string::npos) { + std::string firstPart = filterString.substr(0, locationDot1); + const auto hardwareType = getRootHardwareType(firstPart); + const auto firstDeviceId = getDeviceId(firstPart); + // first dot found, look for another + std::string::size_type locationDot2 = + filterString.find('.', locationDot1 + 1); + std::string secondPart = filterString.substr( + locationDot1 + 1, locationDot2 == std::string::npos + ? std::string::npos + : locationDot2 - locationDot1); + const auto secondDeviceId = getDeviceId(secondPart); + if (locationDot2 != std::string::npos) { + // second dot found, this is a subsubdevice + std::string thirdPart = + filterString.substr(locationDot2 + 1); + const auto thirdDeviceId = getDeviceId(thirdPart); + deviceList.push_back(DeviceSpec{ + DevicePartLevel::SUBSUB, hardwareType, firstDeviceId, + secondDeviceId, thirdDeviceId}); + } else { + // second dot not found, this is a subdevice + deviceList.push_back(DeviceSpec{DevicePartLevel::SUB, + hardwareType, firstDeviceId, + secondDeviceId}); + } + } else { + // first dot not found, this is a root device + const auto hardwareType = getRootHardwareType(filterString); + const auto firstDeviceId = getDeviceId(filterString); + deviceList.push_back(DeviceSpec{DevicePartLevel::ROOT, + hardwareType, firstDeviceId}); + } + } + } + + if (acceptDeviceList.size() == 0 && discardDeviceList.size() == 0) { + // nothing in env var was understood as a valid term + return UR_RESULT_ERROR_INVALID_VALUE; + } else if (acceptDeviceList.size() == 0) { + // no accept terms were understood, but at least one discard term was + // we are magnanimous to the user when there were bad/ignored accept terms + // by pretending there were no bad/ignored accept terms in the env var + // for example, we pretend that "garbage:0;!cuda:*" was just "!cuda:*" + // so we add an implicit accept-all term (equivalent to prepending "*:*;") + // as we would have done if the user had given us the corrected string + acceptDeviceList.push_back(DeviceSpec{ + DevicePartLevel::ROOT, ::UR_DEVICE_TYPE_ALL, DeviceIdTypeALL}); + } + + logger::debug("DEBUG: size of acceptDeviceList = {}", + acceptDeviceList.size()); + logger::debug("DEBUG: size of discardDeviceList = {}", + discardDeviceList.size()); + + std::vector rootDevices; + std::vector subDevices; + std::vector subSubDevices; + + // To support root device terms: + { + uint32_t platformNumRootDevicesAll = 0; + if (UR_RESULT_SUCCESS != urDeviceGet(hPlatform, UR_DEVICE_TYPE_ALL, 0, + nullptr, + &platformNumRootDevicesAll)) { + return UR_RESULT_ERROR_DEVICE_NOT_FOUND; + } + std::vector rootDeviceHandles( + platformNumRootDevicesAll); + auto pRootDevices = rootDeviceHandles.data(); + if (UR_RESULT_SUCCESS != urDeviceGet(hPlatform, UR_DEVICE_TYPE_ALL, + platformNumRootDevicesAll, + pRootDevices, 0)) { + return UR_RESULT_ERROR_DEVICE_NOT_FOUND; + } + + DeviceIdType deviceCount = 0; + std::transform( + rootDeviceHandles.cbegin(), rootDeviceHandles.cend(), + std::back_inserter(rootDevices), + [&](ur_device_handle_t urDeviceHandle) { + // obtain and record device type from platform (squash errors) + ur_device_type_t hardwareType = ::UR_DEVICE_TYPE_DEFAULT; + urDeviceGetInfo(urDeviceHandle, UR_DEVICE_INFO_TYPE, + sizeof(ur_device_type_t), &hardwareType, 0); + return DeviceSpec{DevicePartLevel::ROOT, hardwareType, + deviceCount++, DeviceIdTypeALL, + DeviceIdTypeALL, urDeviceHandle}; + }); + + // apply the function parameter: ur_device_type_t DeviceType + // remove_if(..., urDeviceHandle->deviceType == DeviceType) + rootDevices.erase( + std::remove_if( + rootDevices.begin(), rootDevices.end(), + [DeviceType](DeviceSpec &device) { + const bool keep = + (DeviceType == + DeviceHardwareType::UR_DEVICE_TYPE_ALL) || + (DeviceType == + DeviceHardwareType::UR_DEVICE_TYPE_DEFAULT) || + (DeviceType == device.hwType); + return !keep; + }), + rootDevices.end()); + } + + // To support sub-device terms: + std::for_each( + rootDevices.cbegin(), rootDevices.cend(), [&](DeviceSpec device) { + ur_device_partition_property_t propNextPart{ + UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN, + {UR_DEVICE_AFFINITY_DOMAIN_FLAG_NEXT_PARTITIONABLE}}; + ur_device_partition_properties_t partitionProperties{ + UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES, nullptr, + &propNextPart, 1}; + uint32_t numSubdevices = 0; + if (UR_RESULT_SUCCESS != + urDevicePartition(device.urDeviceHandle, &partitionProperties, + 0, nullptr, &numSubdevices)) { + return UR_RESULT_ERROR_DEVICE_PARTITION_FAILED; + } + std::vector subDeviceHandles(numSubdevices); + auto pSubDevices = subDeviceHandles.data(); + if (UR_RESULT_SUCCESS != + urDevicePartition(device.urDeviceHandle, &partitionProperties, + numSubdevices, pSubDevices, 0)) { + return UR_RESULT_ERROR_DEVICE_PARTITION_FAILED; + } + DeviceIdType subDeviceCount = 0; + std::transform(subDeviceHandles.cbegin(), subDeviceHandles.cend(), + std::back_inserter(subDevices), + [&](ur_device_handle_t urDeviceHandle) { + return DeviceSpec{ + DevicePartLevel::SUB, device.hwType, + device.rootId, subDeviceCount++, + DeviceIdTypeALL, urDeviceHandle}; + }); + return UR_RESULT_SUCCESS; + }); + + // To support sub-sub-device terms: + std::for_each( + subDevices.cbegin(), subDevices.cend(), [&](DeviceSpec device) { + ur_device_partition_property_t propNextPart{ + UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN, + {UR_DEVICE_AFFINITY_DOMAIN_FLAG_NEXT_PARTITIONABLE}}; + ur_device_partition_properties_t partitionProperties{ + UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES, nullptr, + &propNextPart, 1}; + uint32_t numSubSubdevices = 0; + if (UR_RESULT_SUCCESS != + urDevicePartition(device.urDeviceHandle, &partitionProperties, + 0, nullptr, &numSubSubdevices)) { + return UR_RESULT_ERROR_DEVICE_PARTITION_FAILED; + } + std::vector subSubDeviceHandles( + numSubSubdevices); + auto pSubSubDevices = subSubDeviceHandles.data(); + if (UR_RESULT_SUCCESS != + urDevicePartition(device.urDeviceHandle, &partitionProperties, + numSubSubdevices, pSubSubDevices, 0)) { + return UR_RESULT_ERROR_DEVICE_PARTITION_FAILED; + } + DeviceIdType subSubDeviceCount = 0; + std::transform( + subSubDeviceHandles.cbegin(), subSubDeviceHandles.cend(), + std::back_inserter(subSubDevices), + [&](ur_device_handle_t urDeviceHandle) { + return DeviceSpec{DevicePartLevel::SUBSUB, device.hwType, + device.rootId, device.subId, + subSubDeviceCount++, urDeviceHandle}; + }); + return UR_RESULT_SUCCESS; + }); + + auto ApplyFilter = [&](DeviceSpec &filter, DeviceSpec &device) -> bool { + bool matches = false; + if (filter.rootId == DeviceIdTypeALL) { + // if this is a root device filter, then it must be '*' or 'cpu' or 'gpu' or 'fpga' + // if this is a subdevice filter, then it must be '*.*' + // if this is a subsubdevice filter, then it must be '*.*.*' + matches = (filter.hwType == device.hwType) || + (filter.hwType == DeviceHardwareType::UR_DEVICE_TYPE_ALL); + logger::debug( + "DEBUG: In ApplyFilter, if block case 1, matches = {}", + matches); + } else if (filter.rootId != device.rootId) { + // root part in filter is a number but does not match the number in the root part of device + matches = false; + logger::debug("DEBUG: In ApplyFilter, if block case 2, matches = ", + matches); + } else if (filter.level == DevicePartLevel::ROOT) { + // this is a root device filter with a number that matches + matches = true; + logger::debug("DEBUG: In ApplyFilter, if block case 3, matches = ", + matches); + } else if (filter.subId == DeviceIdTypeALL) { + // sub type of star always matches (when root part matches, which we already know here) + // if this is a subdevice filter, then it must be 'matches.*' + // if this is a subsubdevice filter, then it must be 'matches.*.*' + matches = true; + logger::debug("DEBUG: In ApplyFilter, if block case 4, matches = ", + matches); + } else if (filter.subId != device.subId) { + // sub part in filter is a number but does not match the number in the sub part of device + matches = false; + logger::debug("DEBUG: In ApplyFilter, if block case 5, matches = ", + matches); + } else if (filter.level == DevicePartLevel::SUB) { + // this is a sub device number filter, numbers match in both parts + matches = true; + logger::debug("DEBUG: In ApplyFilter, if block case 6, matches = ", + matches); + } else if (filter.subsubId == DeviceIdTypeALL) { + // subsub type of star always matches (when other parts match, which we already know here) + // this is a subsub device filter, it must be 'matches.matches.*' + matches = true; + logger::debug("DEBUG: In ApplyFilter, if block case 7, matches = ", + matches); + } else { + // this is a subsub device filter, numbers in all three parts match + matches = (filter.subsubId == device.subsubId); + logger::debug("DEBUG: In ApplyFilter, if block case 8, matches = ", + matches); + } + return matches; + }; + + // apply each discard filter in turn by removing all matching elements + // from the appropriate device handle vector returned by the platform; + // no side-effect: the matching devices are just removed and discarded + for (auto &discard : discardDeviceList) { + auto ApplyDiscardFilter = [&](auto &device) -> bool { + return ApplyFilter(discard, device); + }; + if (discard.level == DevicePartLevel::ROOT) { + rootDevices.erase(std::remove_if(rootDevices.begin(), + rootDevices.end(), + ApplyDiscardFilter), + rootDevices.end()); + } + if (discard.level == DevicePartLevel::SUB) { + subDevices.erase(std::remove_if(subDevices.begin(), + subDevices.end(), + ApplyDiscardFilter), + subDevices.end()); + } + if (discard.level == DevicePartLevel::SUBSUB) { + subSubDevices.erase(std::remove_if(subSubDevices.begin(), + subSubDevices.end(), + ApplyDiscardFilter), + subSubDevices.end()); + } + } + + std::vector selectedDevices; + + // apply each accept filter in turn by removing all matching elements + // from the appropriate device handle vector returned by the platform + // but using a predicate with a side-effect that takes a copy of each + // of the accepted device handles just before they are removed + // removing each item as it is selected prevents us taking duplicates + // without needing O(n^2) de-duplicatation or symbolic simplification + for (auto &accept : acceptDeviceList) { + auto ApplyAcceptFilter = [&](auto &device) -> bool { + const bool matches = ApplyFilter(accept, device); + if (matches) { + selectedDevices.push_back(device.urDeviceHandle); + } + return matches; + }; + auto numAlreadySelected = selectedDevices.size(); + if (accept.level == DevicePartLevel::ROOT) { + rootDevices.erase(std::remove_if(rootDevices.begin(), + rootDevices.end(), + ApplyAcceptFilter), + rootDevices.end()); + } + if (accept.level == DevicePartLevel::SUB) { + subDevices.erase(std::remove_if(subDevices.begin(), + subDevices.end(), + ApplyAcceptFilter), + subDevices.end()); + } + if (accept.level == DevicePartLevel::SUBSUB) { + subSubDevices.erase(std::remove_if(subSubDevices.begin(), + subSubDevices.end(), + ApplyAcceptFilter), + subSubDevices.end()); + } + if (numAlreadySelected == selectedDevices.size()) { + logger::warning("WARNING: an accept term was ignored because it " + "does not select any additional devices" + "selectedDevices.size() = {}", + selectedDevices.size()); + } + } + + // selectedDevices is now a vector containing all the right device handles + + // should we return the size of the vector or the content of the vector? + if (NumEntries == 0) { + *pNumDevices = static_cast(selectedDevices.size()); + } else if (NumEntries > 0) { + size_t numToCopy = std::min((size_t)NumEntries, selectedDevices.size()); + std::copy_n(selectedDevices.cbegin(), numToCopy, phDevices); + if (pNumDevices != nullptr) { + *pNumDevices = static_cast(numToCopy); + return UR_RESULT_ERROR_ADAPTER_SPECIFIC; + } + } + + return UR_RESULT_SUCCESS; +} } // namespace ur_lib diff --git a/source/loader/ur_lib.hpp b/source/loader/ur_lib.hpp index 41ab7cb52e..839c0041d9 100644 --- a/source/loader/ur_lib.hpp +++ b/source/loader/ur_lib.hpp @@ -105,5 +105,10 @@ urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig, ur_code_location_callback_t pfnCodeloc, void *pUserData); +ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, + ur_device_type_t DeviceType, + uint32_t NumEntries, + ur_device_handle_t *phDevices, + uint32_t *pNumDevices); } // namespace ur_lib #endif /* UR_LOADER_LIB_H */ diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 1e9400aaa4..f8f81ea9d5 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -777,6 +777,52 @@ ur_result_t UR_APICALL urDeviceGet( return exceptionToResult(std::current_exception()); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR +/// +/// @details +/// - Multiple calls to this function will return identical device handles, +/// in the same order. +/// - The number and order of handles returned from this function will be +/// affected by environment variables that filter or select which devices +/// are exposed through this API. +/// - A reference is taken for each returned device and must be released +/// with a subsequent call to ::urDeviceRelease. +/// - The application may call this function from simultaneous threads, the +/// implementation must be thread-safe. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hPlatform` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_DEVICE_TYPE_VPU < DeviceType` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +ur_result_t UR_APICALL urDeviceGetSelected( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance + ur_device_type_t DeviceType, ///< [in] the type of the devices. + uint32_t + NumEntries, ///< [in] the number of devices to be added to phDevices. + ///< If phDevices in not NULL then NumEntries should be greater than zero, + ///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE, + ///< will be returned. + ur_device_handle_t * + phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices. + ///< If NumEntries is less than the number of devices available, then only + ///< that number of devices will be retrieved. + uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices. + ///< pNumDevices will be updated with the total number of selected devices + ///< available for the given platform. + ) try { + return ur_lib::urDeviceGetSelected(hPlatform, DeviceType, NumEntries, + phDevices, pNumDevices); +} catch (...) { + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Retrieves various information about device /// diff --git a/source/loader/ur_print.cpp b/source/loader/ur_print.cpp index 1d8b3ca9af..1e8ad88086 100644 --- a/source/loader/ur_print.cpp +++ b/source/loader/ur_print.cpp @@ -2523,6 +2523,14 @@ ur_result_t urPrintDeviceGetParams(const struct ur_device_get_params_t *params, return str_copy(&ss, buffer, buff_size, out_size); } +ur_result_t urPrintDeviceGetSelectedParams( + const struct ur_device_get_selected_params_t *params, char *buffer, + const size_t buff_size, size_t *out_size) { + std::stringstream ss; + ss << params; + return str_copy(&ss, buffer, buff_size, out_size); +} + ur_result_t urPrintDeviceGetInfoParams(const struct ur_device_get_info_params_t *params, char *buffer, const size_t buff_size, diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 5ee68ce529..6dcc7b4d56 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -673,6 +673,50 @@ ur_result_t UR_APICALL urDeviceGet( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR +/// +/// @details +/// - Multiple calls to this function will return identical device handles, +/// in the same order. +/// - The number and order of handles returned from this function will be +/// affected by environment variables that filter or select which devices +/// are exposed through this API. +/// - A reference is taken for each returned device and must be released +/// with a subsequent call to ::urDeviceRelease. +/// - The application may call this function from simultaneous threads, the +/// implementation must be thread-safe. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hPlatform` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_DEVICE_TYPE_VPU < DeviceType` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +ur_result_t UR_APICALL urDeviceGetSelected( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance + ur_device_type_t DeviceType, ///< [in] the type of the devices. + uint32_t + NumEntries, ///< [in] the number of devices to be added to phDevices. + ///< If phDevices in not NULL then NumEntries should be greater than zero, + ///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE, + ///< will be returned. + ur_device_handle_t * + phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices. + ///< If NumEntries is less than the number of devices available, then only + ///< that number of devices will be retrieved. + uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices. + ///< pNumDevices will be updated with the total number of selected devices + ///< available for the given platform. +) { + ur_result_t result = UR_RESULT_SUCCESS; + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Retrieves various information about device /// diff --git a/test/conformance/CMakeLists.txt b/test/conformance/CMakeLists.txt index ac48f3a313..a03477bd1c 100644 --- a/test/conformance/CMakeLists.txt +++ b/test/conformance/CMakeLists.txt @@ -37,7 +37,8 @@ function(add_conformance_test name) ${PROJECT_NAME}::headers ${PROJECT_NAME}::testing ${PROJECT_NAME}::common - GTest::gtest_main) + GTest::gtest_main + unit_tests_helpers) if(UR_BUILD_ADAPTER_CUDA OR UR_BUILD_ADAPTER_ALL) add_test_adapter(${name} adapter_cuda) diff --git a/test/conformance/device/CMakeLists.txt b/test/conformance/device/CMakeLists.txt index 23ff5b4ebc..0f7da3d80c 100644 --- a/test/conformance/device/CMakeLists.txt +++ b/test/conformance/device/CMakeLists.txt @@ -9,6 +9,7 @@ add_conformance_test_with_platform_environment(device urDeviceGetGlobalTimestamps.cpp urDeviceGetInfo.cpp urDeviceGetNativeHandle.cpp + urDeviceGetSelected.cpp urDevicePartition.cpp urDeviceRelease.cpp urDeviceRetain.cpp diff --git a/test/conformance/device/urDeviceGetSelected.cpp b/test/conformance/device/urDeviceGetSelected.cpp new file mode 100644 index 0000000000..953b418e24 --- /dev/null +++ b/test/conformance/device/urDeviceGetSelected.cpp @@ -0,0 +1,249 @@ +// Copyright (C) 2024 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "helpers.h" +#include + +using urDeviceGetSelectedTest = uur::urPlatformTest; + +/* adpater agnostic tests -- none assume the existence or support of any specific adapter */ + +TEST_F(urDeviceGetSelectedTest, Success) { + unsetenv("ONEAPI_DEVICE_SELECTOR"); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_NE(count, 0); + std::vector devices(count); + ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count, + devices.data(), nullptr)); + for (auto &device : devices) { + ASSERT_NE(nullptr, device); + } +} + +TEST_F(urDeviceGetSelectedTest, SuccessSubsetOfDevices) { + unsetenv("ONEAPI_DEVICE_SELECTOR"); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + if (count < 2) { + GTEST_SKIP() << "There are fewer than two devices in total for the " + "platform so the subset test is impossible"; + } + std::vector devices(count - 1); + ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count - 1, + devices.data(), nullptr)); + for (auto device : devices) { + ASSERT_NE(nullptr, device); + } +} + +TEST_F(urDeviceGetSelectedTest, SuccessSelected_StarColonStar) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:*", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_NE(count, 0); + std::vector devices(count); + ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count, + devices.data(), nullptr)); + for (auto &device : devices) { + ASSERT_NE(nullptr, device); + } + + uint32_t countAll = 0; + ASSERT_SUCCESS( + urDeviceGet(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &countAll)); + ASSERT_NE(countAll, 0); + ASSERT_EQ(countAll, count); + std::vector devicesAll(countAll); + ASSERT_SUCCESS(urDeviceGet(platform, UR_DEVICE_TYPE_ALL, countAll, + devicesAll.data(), nullptr)); + for (auto &device : devicesAll) { + ASSERT_NE(nullptr, device); + } + + for (size_t i = 0; i < count; ++i) { + ASSERT_EQ(devices[i], devicesAll[i]); + } +} + +TEST_F(urDeviceGetSelectedTest, SuccessSelected_StarColonZero) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:0", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_NE(count, 0); + std::vector devices(count); + ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count, + devices.data(), nullptr)); + for (auto &device : devices) { + ASSERT_NE(nullptr, device); + } +} + +TEST_F(urDeviceGetSelectedTest, SuccessSelected_StarColonZeroCommaStar) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:0,*", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_NE(count, 0); + std::vector devices(count); + ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count, + devices.data(), nullptr)); + for (auto &device : devices) { + ASSERT_NE(nullptr, device); + } +} + +TEST_F(urDeviceGetSelectedTest, SuccessSelected_DiscardStarColonStar) { + setenv("ONEAPI_DEVICE_SELECTOR", "!*:*", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +} + +TEST_F(urDeviceGetSelectedTest, SuccessSelected_SelectAndDiscard) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:0;!*:*", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +} + +TEST_F(urDeviceGetSelectedTest, + SuccessSelected_SelectSomethingAndDiscardSomethingElse) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:0;!*:1", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_NE(count, 0); + std::vector devices(count); + ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count, + devices.data(), nullptr)); + for (auto &device : devices) { + ASSERT_NE(nullptr, device); + } +} + +TEST_F(urDeviceGetSelectedTest, InvalidNullHandlePlatform) { + unsetenv("ONEAPI_DEVICE_SELECTOR"); + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_NULL_HANDLE, + urDeviceGetSelected(nullptr, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); +} + +TEST_F(urDeviceGetSelectedTest, InvalidEnumerationDevicesType) { + unsetenv("ONEAPI_DEVICE_SELECTOR"); + uint32_t count = 0; + ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_ENUMERATION, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_FORCE_UINT32, + 0, nullptr, &count)); +} + +TEST_F(urDeviceGetSelectedTest, InvalidValueNumEntries) { + unsetenv("ONEAPI_DEVICE_SELECTOR"); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_NE(count, 0); + std::vector devices(count); + ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_SIZE, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, + devices.data(), nullptr)); +} + +TEST_F(urDeviceGetSelectedTest, InvalidMissingBackend) { + setenv("ONEAPI_DEVICE_SELECTOR", ":garbage", 1); + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_UNKNOWN, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +} + +TEST_F(urDeviceGetSelectedTest, InvalidGarbageBackendString) { + setenv("ONEAPI_DEVICE_SELECTOR", "garbage:0", 1); + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_VALUE, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +} + +TEST_F(urDeviceGetSelectedTest, InvalidMissingFilterStrings) { + setenv("ONEAPI_DEVICE_SELECTOR", "*", 1); + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_VALUE, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); + setenv("ONEAPI_DEVICE_SELECTOR", "*:", 1); + uint32_t count2 = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_VALUE, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count2)); + ASSERT_EQ(count2, 0); +} + +TEST_F(urDeviceGetSelectedTest, InvalidMissingFilterString) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:0,,2", 1); + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_UNKNOWN, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +} + +TEST_F(urDeviceGetSelectedTest, InvalidTooManyDotsInFilterString) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:0.1.2.3", 1); + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_VALUE, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +} + +TEST_F(urDeviceGetSelectedTest, InvalidBadWildardInFilterString) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:*.", 1); + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_VALUE, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); + setenv("ONEAPI_DEVICE_SELECTOR", "*:*.0", 1); + uint32_t count2 = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_VALUE, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count2)); + ASSERT_EQ(count2, 0); +} + +TEST_F(urDeviceGetSelectedTest, InvalidSelectingNonexistentDevice) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:4321", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +} + +TEST_F(urDeviceGetSelectedTest, InvalidSelectingNonexistentSubDevice) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:0.4321", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +} + +TEST_F(urDeviceGetSelectedTest, InvalidSelectingNonexistentSubSubDevice) { + setenv("ONEAPI_DEVICE_SELECTOR", "*:0.0.4321", 1); + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_EQ(count, 0); +}