Skip to content

Commit

Permalink
[UR][Tests] Add options in CTS to set test devices count or test device
Browse files Browse the repository at this point in the history
name and test platforms count or test platform name
  • Loading branch information
szadam committed Nov 28, 2023
1 parent 534071e commit a2b720b
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 9 deletions.
118 changes: 109 additions & 9 deletions test/conformance/source/environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <cstring>
#include <fstream>

Expand All @@ -12,6 +13,7 @@
#include "kernel_entry_points.h"
#endif

#include <ur_util.hpp>
#include <uur/environment.h>
#include <uur/utils.h>

Expand All @@ -32,6 +34,23 @@ std::ostream &operator<<(std::ostream &out,
return out;
}

std::ostream &operator<<(std::ostream &out, const ur_device_handle_t &device) {
size_t size;
urDeviceGetInfo(device, UR_DEVICE_INFO_NAME, 0, nullptr, &size);
std::vector<char> name(size);
urDeviceGetInfo(device, UR_DEVICE_INFO_NAME, size, name.data(), nullptr);
out << name.data();
return out;
}

std::ostream &operator<<(std::ostream &out,
const std::vector<ur_device_handle_t> &devices) {
for (auto device : devices) {
out << "\n * \"" << device << "\"";
}
return out;
}

std::ostream &operator<<(std::ostream &out,
const std::vector<ur_platform_handle_t> &platforms) {
for (auto platform : platforms) {
Expand Down Expand Up @@ -99,14 +118,16 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
}

if (platform_options.platform_name.empty()) {
if (platforms.size() == 1) {

if (platforms.size() == 1 || platform_options.platforms_count == 1) {
platform = platforms[0];
} else {
std::stringstream ss_error;
ss_error << "Select a single platform from below using the "
"--platform=NAME "
"command-line option:"
<< platforms;
<< platforms << std::endl
<< "or set --platforms_count=1.";
error = ss_error.str();
return;
}
Expand Down Expand Up @@ -135,7 +156,8 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
<< "\" not found. Select a single platform from below "
"using the "
"--platform=NAME command-line options:"
<< platforms;
<< platforms << std::endl
<< "or set --platforms_count=1.";
error = ss_error.str();
return;
}
Expand Down Expand Up @@ -177,6 +199,30 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
arg, "--platform=", sizeof("--platform=") - 1) == 0) {
options.platform_name =
std::string(&arg[std::strlen("--platform=")]);
} else if (std::strncmp(arg, "--platforms_count=",
sizeof("--platforms_count=") - 1) == 0) {
options.platforms_count = std::strtoul(
&arg[std::strlen("--platforms_count=")], nullptr, 10);
}
}
return options;
}

DevicesEnvironment::DeviceOptions
DevicesEnvironment::parseDeviceOptions(int argc, char **argv) {
DeviceOptions options;
for (int argi = 1; argi < argc; ++argi) {
const char *arg = argv[argi];
if (!(std::strcmp(arg, "-h") && std::strcmp(arg, "--help"))) {
// TODO - print help
break;
} else if (std::strncmp(arg, "--device=", sizeof("--device=") - 1) ==
0) {
options.device_name = std::string(&arg[std::strlen("--device=")]);
} else if (std::strncmp(arg, "--devices_count=",
sizeof("--devices_count=") - 1) == 0) {
options.devices_count = std::strtoul(
&arg[std::strlen("--devices_count=")], nullptr, 10);
}
}
return options;
Expand All @@ -185,7 +231,8 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
DevicesEnvironment *DevicesEnvironment::instance = nullptr;

DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
: PlatformEnvironment(argc, argv) {
: PlatformEnvironment(argc, argv),
device_options(parseDeviceOptions(argc, argv)) {
instance = this;
if (!error.empty()) {
return;
Expand All @@ -199,11 +246,64 @@ DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
error = "Could not find any devices associated with the platform";
return;
}
devices.resize(count);
if (urDeviceGet(platform, UR_DEVICE_TYPE_ALL, count, devices.data(),
nullptr)) {
error = "urDeviceGet() failed to get devices.";
return;

// Get the argument (devices_count) to limit test devices count.
// In case, the devices_count is "0", the variable count will not be changed.
// The CTS will run on all devices.
if (device_options.device_name.empty()) {
if (device_options.devices_count >
(std::numeric_limits<uint32_t>::max)()) {
error = "Invalid devices_count argument";
return;
} else if (device_options.devices_count > 0) {
count = (std::min)(
count, static_cast<uint32_t>(device_options.devices_count));
}
devices.resize(count);
if (urDeviceGet(platform, UR_DEVICE_TYPE_ALL, count, devices.data(),
nullptr)) {
error = "urDeviceGet() failed to get devices.";
return;
}
} else {
devices.resize(count);
if (urDeviceGet(platform, UR_DEVICE_TYPE_ALL, count, devices.data(),
nullptr)) {
error = "urDeviceGet() failed to get devices.";
return;
}
for (u_long i = 0; i < count; i++) {
size_t size;
if (urDeviceGetInfo(devices[i], UR_DEVICE_INFO_NAME, 0, nullptr,
&size)) {
error = "urDeviceGetInfo() failed";
return;
}
std::vector<char> device_name(size);
if (urDeviceGetInfo(devices[i], UR_DEVICE_INFO_NAME, size,
device_name.data(), nullptr)) {
error = "urDeviceGetInfo() failed";
return;
}
if (device_options.device_name == device_name.data()) {
device = devices[i];
devices.clear();
devices.resize(1);
devices[0] = device;
break;
}
}
if (!device) {
std::stringstream ss_error;
ss_error << "Device \"" << device_options.device_name
<< "\" not found. Select a single device from below "
"using the "
"--device=NAME command-line options:"
<< devices << std::endl
<< "or set --devices_count=COUNT.";
error = ss_error.str();
return;
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions test/conformance/testing/include/uur/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct PlatformEnvironment : ::testing::Environment {

struct PlatformOptions {
std::string platform_name;
u_long platforms_count;
};

PlatformEnvironment(int argc, char **argv);
Expand All @@ -36,17 +37,26 @@ struct PlatformEnvironment : ::testing::Environment {

struct DevicesEnvironment : PlatformEnvironment {

struct DeviceOptions {
std::string device_name;
u_long devices_count;
};

DevicesEnvironment(int argc, char **argv);
virtual ~DevicesEnvironment() override = default;

virtual void SetUp() override;
virtual void TearDown() override;

DeviceOptions parseDeviceOptions(int argc, char **argv);

inline const std::vector<ur_device_handle_t> &GetDevices() const {
return devices;
}

DeviceOptions device_options;
std::vector<ur_device_handle_t> devices;
ur_device_handle_t device = nullptr;
static DevicesEnvironment *instance;
};

Expand Down

0 comments on commit a2b720b

Please sign in to comment.