Skip to content

Commit

Permalink
[Loader] Refactor ONEAPI_DEVICE_SELECTOR, add tests
Browse files Browse the repository at this point in the history
This patch resolves a number of issues with the extant implementation of
the `ONEAPI_DEVICE_SELECTOR`:

* Filtering only being applied to devices within a platform, resulting
  in an incorrect list of devices being returned in a multi-platform,
  multi-device context.
* Device indices being counted only inside a platform, not globally,
  resulting in multiple devices with 0 index in the presence of
  multiple platforms.
* Lack of testing for non-hardware dependent configurations, e.g. unable
  to emulate filtering against hypothetical sets of devices.
* Parsing the ONEAPI_DEVICE_SELECTOR string each time
  `urDeviceGetSelected()` was called.

This new implementation was written with testing of the individaul terms
in the ONEAPI_DEVICE_SELECTOR BNF grammer as a design requirement. It
decouples the enumeration of devices from the filtering logic. Both
parsing and device enumeration are moved much earlier, into loader
initialization, and performed only once.
  • Loading branch information
kbenzie committed Apr 3, 2024
1 parent a0f3c51 commit aa914df
Show file tree
Hide file tree
Showing 21 changed files with 1,564 additions and 602 deletions.
5 changes: 1 addition & 4 deletions source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(
std::copy_n(MatchedDevices.begin(), N, Devices);

if (NumDevices) {
if (*NumDevices == 0)
*NumDevices = ZeDeviceCount;
else
*NumDevices = N;
*NumDevices = ZeDeviceCount;
}

return UR_RESULT_SUCCESS;
Expand Down
2 changes: 2 additions & 0 deletions source/loader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# See LICENSE.TXT
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(device_selector)

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/UrLoaderVersion.rc.in
${CMAKE_CURRENT_BINARY_DIR}/UrLoaderVersion.rc
Expand Down
17 changes: 17 additions & 0 deletions source/loader/device_selector/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (C) 2022 Intel Corporation
# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
# See LICENSE.TXT
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_library(ur_device_selector INTERFACE)

add_library(${PROJECT_NAME}::device_selector ALIAS ur_device_selector)

target_include_directories(ur_device_selector INTERFACE
${CMAKE_CURRENT_SOURCE_DIR}/..
)

target_link_libraries(ur_device_selector INTERFACE
${PROJECT_NAME}::headers
${PROJECT_NAME}::common
)
89 changes: 89 additions & 0 deletions source/loader/device_selector/backend.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (C) 2024 Intel Corporation
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include "descriptor.hpp"
#include "ur_api.h"
#include <algorithm>
#include <optional>
#include <string>
#include <vector>

namespace ur::device_selector {

struct BackendMatcher {
/// @brief Initilalize the backend matcher from a string in a case
/// insensitive way.
/// @param[in] backend Backend filter string.
/// @return Returns a diagnostic if an invalid backend is found.
std::optional<Diagnostic> init(std::string_view str) {
if (str.empty()) {
return Diagnostic("empty backend");
}
std::string lower(str);
std::transform(lower.cbegin(), lower.cend(), lower.begin(),
[](char c) { return std::tolower(c); });
if (lower == "*") {
matches.insert(matches.begin(), {
UR_PLATFORM_BACKEND_UNKNOWN,
UR_PLATFORM_BACKEND_LEVEL_ZERO,
UR_PLATFORM_BACKEND_OPENCL,
UR_PLATFORM_BACKEND_CUDA,
UR_PLATFORM_BACKEND_HIP,
UR_PLATFORM_BACKEND_OPENCL,
UR_PLATFORM_BACKEND_NATIVE_CPU,
});
} else if (lower == "opencl") {
matches.push_back(UR_PLATFORM_BACKEND_OPENCL);
} else if (lower == "level_zero" || lower == "ext_oneapi_level_zero") {
matches.push_back(UR_PLATFORM_BACKEND_LEVEL_ZERO);
} else if (lower == "cuda" || lower == "ext_oneapi_cuda") {
matches.push_back(UR_PLATFORM_BACKEND_CUDA);
} else if (lower == "hip" || lower == "ext_oneapi_hip") {
matches.push_back(UR_PLATFORM_BACKEND_HIP);
} else if (lower == "native_cpu") {
matches.push_back(UR_PLATFORM_BACKEND_NATIVE_CPU);
} else {
std::string diagnostic = "invalid backend: '";
diagnostic.append(str);
return Diagnostic(diagnostic + "'");
}
return std::nullopt;
}

friend bool operator==(const BackendMatcher &matcher,
const Descriptor &descriptor) {
return std::any_of(
matcher.matches.begin(), matcher.matches.end(),
[&descriptor](const ur_platform_backend_t &backendMatch) {
return backendMatch == descriptor.backend;
});
}

friend bool operator==(const Descriptor &backend,
const BackendMatcher &matcher) {
return matcher == backend;
}

friend bool operator!=(const BackendMatcher &matcher,
const Descriptor &descriptor) {
return std::none_of(
matcher.matches.begin(), matcher.matches.end(),
[&descriptor](const ur_platform_backend_t &backendMatch) {
return backendMatch == descriptor.backend;
});
}

friend bool operator!=(const Descriptor &backend,
const BackendMatcher &matcher) {
return matcher != backend;
}

private:
std::vector<ur_platform_backend_t> matches;
};

} // namespace ur::device_selector
46 changes: 46 additions & 0 deletions source/loader/device_selector/descriptor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2024 Intel Corporation
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include "ur_api.h"
#include <optional>
#include <string>

namespace ur::device_selector {

struct Diagnostic {
Diagnostic(std::string message) : message(message) {}
std::string message;
};

struct Descriptor {
Descriptor() = default;

Descriptor(const Descriptor &other)
: backend(other.backend), type(other.type), index(other.index),
subIndex(other.subIndex), subSubIndex(other.subSubIndex) {}

Descriptor(ur_platform_backend_t backend, ur_device_type_t type,
uint32_t index)
: backend(backend), type(type), index(index) {}
Descriptor(ur_platform_backend_t backend, ur_device_type_t type,
uint32_t index, uint32_t subDeviceIndex)
: backend(backend), type(type), index(index), subIndex(subDeviceIndex) {
}
Descriptor(ur_platform_backend_t backend, ur_device_type_t type,
uint32_t index, uint32_t subDeviceIndex,
uint32_t subSubDeviceIndex)
: backend(backend), type(type), index(index), subIndex(subDeviceIndex),
subSubIndex(subSubDeviceIndex) {}

ur_platform_backend_t backend;
ur_device_type_t type;
uint32_t index = 0;
std::optional<uint32_t> subIndex;
std::optional<uint32_t> subSubIndex;
};

} // namespace ur::device_selector
Loading

0 comments on commit aa914df

Please sign in to comment.