Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][E2E] Fix infinite loop bug in Config/select_device.cpp #12814

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 47 additions & 75 deletions sycl/test-e2e/Config/select_device.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
// REQUIRES: gpu
// Post-commit fails due to a bug in test, will fix in a couple of days.
// UNSUPPORTED: gpu-intel-dg2
// RUN: %{build} -o %t.out
//
// RUN: env ONEAPI_DEVICE_SELECTOR="*:gpu" %{run-unfiltered-devices} %t.out DEVICE_INFO write > %t.txt
Expand Down Expand Up @@ -86,92 +84,66 @@ static void addEscapeSymbolToSpecialCharacters(std::string &str) {
}
}

static std::vector<DevDescT> getAllowListDesc(std::string allowList) {
static std::vector<DevDescT> getAllowListDesc(std::string_view allowList) {
if (allowList.empty())
return {};

std::string deviceName("DeviceName:");
std::string driverVersion("DriverVersion:");
std::string platformName("PlatformName:");
std::string platformVersion("PlatformVersion:");
std::vector<DevDescT> decDescs;
decDescs.emplace_back();

size_t pos = 0;
while (pos < allowList.size()) {
if ((allowList.compare(pos, deviceName.size(), deviceName)) == 0) {
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
throw std::runtime_error("Malformed device allowlist");
}
size_t start = pos + 2;
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
throw std::runtime_error("Malformed device allowlist");
}
decDescs.back().devName = allowList.substr(start, pos - start);
pos = pos + 2;
auto try_parse = [&](std::string_view str) -> std::optional<std::string> {
// std::string_view::starts_with is C++20.
if (allowList.compare(0, str.size(), str) != 0)
return {};

if (allowList[pos] == ',') {
pos++;
}
}
allowList.remove_prefix(str.size());

else if ((allowList.compare(pos, driverVersion.size(), driverVersion)) ==
0) {
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
throw std::runtime_error("Malformed device allowlist");
}
size_t start = pos + 2;
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
throw std::runtime_error("Malformed device allowlist");
}
decDescs.back().devDriverVer = allowList.substr(start, pos - start);
pos = pos + 2;
using namespace std::string_literals;
auto pattern_start = allowList.find("{{");
if (pattern_start == std::string::npos)
throw std::runtime_error("Malformed "s + std::string{str} + " allowlist"s);

if (allowList[pos] == ',') {
pos++;
}
}
allowList.remove_prefix(pattern_start + 2);
auto pattern_end = allowList.find("}}");
if (pattern_end == std::string::npos)
throw std::runtime_error("Malformed "s + std::string{str} + " allowlist"s);

else if ((allowList.compare(pos, platformName.size(), platformName)) == 0) {
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
throw std::runtime_error("Malformed platform allowlist");
}
size_t start = pos + 2;
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
throw std::runtime_error("Malformed platform allowlist");
}
decDescs.back().platName = allowList.substr(start, pos - start);
pos = pos + 2;
if (allowList[pos] == ',') {
pos++;
}
}
auto result = allowList.substr(0, pattern_end);
allowList.remove_prefix(pattern_end + 2);

else if ((allowList.compare(pos, platformVersion.size(),
platformVersion)) == 0) {
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
throw std::runtime_error("Malformed platform allowlist");
}
size_t start = pos + 2;
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
throw std::runtime_error("Malformed platform allowlist");
}
decDescs.back().platVer = allowList.substr(start, pos - start);
pos = pos + 2;
}
if (allowList[0] == ',')
allowList.remove_prefix(1);
return {std::string{result}};
};

else if (allowList.find('|', pos) != std::string::npos) {
// FIXME: That is wrong and result in a infinite loop. We start processing
// the string from the start here.
pos = allowList.find('|') + 1;
while (allowList[pos] == ' ') {
pos++;
}
decDescs.emplace_back();
} else {
throw std::runtime_error("Malformed platform allowlist");
while (!allowList.empty()) {
if (auto pattern = try_parse("DeviceName:")) {
decDescs.back().devName = *pattern;
continue;
}
if (auto pattern = try_parse("DriverVersion:")) {
decDescs.back().devDriverVer = *pattern;
continue;
}
if (auto pattern = try_parse("PlatformName:")) {
decDescs.back().platName = *pattern;
continue;
}
if (auto pattern = try_parse("PlatformVersion:")) {
decDescs.back().platVer = *pattern;
continue;
}
} // while (pos <= allowList.size())

auto next = allowList.find('|');
if (next == std::string::npos)
throw std::runtime_error("Malformed allowlist");
allowList.remove_prefix(next + 1);

auto non_space = allowList.find_first_not_of(" ");
allowList.remove_prefix(non_space);
decDescs.emplace_back();
}

return decDescs;
}

Expand Down
Loading