Skip to content

Commit

Permalink
[UR][Tests] Add options in CTS to set test devices count or test device
Browse files Browse the repository at this point in the history
name and test platforms count or test platform name
  • Loading branch information
szadam committed Nov 29, 2023
1 parent 116fbe1 commit 36d6fde
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 29 deletions.
14 changes: 9 additions & 5 deletions test/conformance/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ When you fix any test, the match file must be updated
Empty match files indicate that there are no failing tests
in a particular group for the corresponding adapter.

## How to limit the test devices count
## How to set test device/platform name or limit the test devices/platforms count

To limit how many devices you want to run the CTS on,
use CMake option UR_TEST_DEVICES_COUNT. If you want to run
the tests on all available devices, set 0.
The default value is 1.
To limit how many devices/platforms you want to run the CTS on,
use CMake option UR_TEST_DEVICES_COUNT or
UR_TEST_PLATFORMS_COUNT. If you want to run the tests on
all available devices/platforms, set 0. The default value is 1.
If you run binaries for the tests, you can use the parameter
`--platforms_count=COUNT/--devices_count=COUNT`.
To set test device/platform name you want to run the CTS on, use
parameter `--platform=NAME/--device=NAME`.
130 changes: 106 additions & 24 deletions test/conformance/source/environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ std::ostream &operator<<(std::ostream &out,
return out;
}

std::ostream &operator<<(std::ostream &out, const ur_device_handle_t &device) {
size_t size;
urDeviceGetInfo(device, UR_DEVICE_INFO_NAME, 0, nullptr, &size);
std::vector<char> name(size);
urDeviceGetInfo(device, UR_DEVICE_INFO_NAME, size, name.data(), nullptr);
out << name.data();
return out;
}

std::ostream &operator<<(std::ostream &out,
const std::vector<ur_device_handle_t> &devices) {
for (auto device : devices) {
out << "\n * \"" << device << "\"";
}
return out;
}

std::ostream &operator<<(std::ostream &out,
const std::vector<ur_platform_handle_t> &platforms) {
for (auto platform : platforms) {
Expand Down Expand Up @@ -101,14 +118,16 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
}

if (platform_options.platform_name.empty()) {
if (platforms.size() == 1) {

if (platforms.size() == 1 || platform_options.platforms_count == 1) {
platform = platforms[0];
} else {
std::stringstream ss_error;
ss_error << "Select a single platform from below using the "
"--platform=NAME "
"command-line option:"
<< platforms;
<< platforms << std::endl
<< "or set --platforms_count=1.";
error = ss_error.str();
return;
}
Expand Down Expand Up @@ -137,7 +156,8 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
<< "\" not found. Select a single platform from below "
"using the "
"--platform=NAME command-line options:"
<< platforms;
<< platforms << std::endl
<< "or set --platforms_count=1.";
error = ss_error.str();
return;
}
Expand Down Expand Up @@ -179,6 +199,10 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
arg, "--platform=", sizeof("--platform=") - 1) == 0) {
options.platform_name =
std::string(&arg[std::strlen("--platform=")]);
} else if (std::strncmp(arg, "--platforms_count=",
sizeof("--platforms_count=") - 1) == 0) {
options.platforms_count = std::strtoul(
&arg[std::strlen("--platforms_count=")], nullptr, 10);
}
}

Expand All @@ -194,10 +218,31 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
return options;
}

DevicesEnvironment::DeviceOptions
DevicesEnvironment::parseDeviceOptions(int argc, char **argv) {
DeviceOptions options;
for (int argi = 1; argi < argc; ++argi) {
const char *arg = argv[argi];
if (!(std::strcmp(arg, "-h") && std::strcmp(arg, "--help"))) {
// TODO - print help
break;
} else if (std::strncmp(arg, "--device=", sizeof("--device=") - 1) ==
0) {
options.device_name = std::string(&arg[std::strlen("--device=")]);
} else if (std::strncmp(arg, "--devices_count=",
sizeof("--devices_count=") - 1) == 0) {
options.devices_count = std::strtoul(
&arg[std::strlen("--devices_count=")], nullptr, 10);
}
}
return options;
}

DevicesEnvironment *DevicesEnvironment::instance = nullptr;

DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
: PlatformEnvironment(argc, argv) {
: PlatformEnvironment(argc, argv),
device_options(parseDeviceOptions(argc, argv)) {
instance = this;
if (!error.empty()) {
return;
Expand All @@ -211,27 +256,64 @@ DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
error = "Could not find any devices associated with the platform";
return;
}
// Get the argument (test_devices_count) to limit test devices count.
u_long count_set = 0;
for (int i = 1; i < argc; ++i) {
if (std::strcmp(argv[i], "--test_devices_count") == 0 && i + 1 < argc) {
count_set = std::strtoul(argv[i + 1], nullptr, 10);
break;
}
}
// In case, the count_set is "0", the variable count will not be changed.

// Get the argument (devices_count) to limit test devices count.
// In case, the devices_count is "0", the variable count will not be changed.
// The CTS will run on all devices.
if (count_set > (std::numeric_limits<uint32_t>::max)()) {
error = "Invalid test_devices_count argument";
return;
} else if (count_set > 0) {
count = (std::min)(count, static_cast<uint32_t>(count_set));
}
devices.resize(count);
if (urDeviceGet(platform, UR_DEVICE_TYPE_ALL, count, devices.data(),
nullptr)) {
error = "urDeviceGet() failed to get devices.";
return;
if (device_options.device_name.empty()) {
if (device_options.devices_count >
(std::numeric_limits<uint32_t>::max)()) {
error = "Invalid devices_count argument";
return;
} else if (device_options.devices_count > 0) {
count = (std::min)(
count, static_cast<uint32_t>(device_options.devices_count));
}
devices.resize(count);
if (urDeviceGet(platform, UR_DEVICE_TYPE_ALL, count, devices.data(),
nullptr)) {
error = "urDeviceGet() failed to get devices.";
return;
}
} else {
devices.resize(count);
if (urDeviceGet(platform, UR_DEVICE_TYPE_ALL, count, devices.data(),
nullptr)) {
error = "urDeviceGet() failed to get devices.";
return;
}
for (u_long i = 0; i < count; i++) {
size_t size;
if (urDeviceGetInfo(devices[i], UR_DEVICE_INFO_NAME, 0, nullptr,
&size)) {
error = "urDeviceGetInfo() failed";
return;
}
std::vector<char> device_name(size);
if (urDeviceGetInfo(devices[i], UR_DEVICE_INFO_NAME, size,
device_name.data(), nullptr)) {
error = "urDeviceGetInfo() failed";
return;
}
if (device_options.device_name == device_name.data()) {
device = devices[i];
devices.clear();
devices.resize(1);
devices[0] = device;
break;
}
}
if (!device) {
std::stringstream ss_error;
ss_error << "Device \"" << device_options.device_name
<< "\" not found. Select a single device from below "
"using the "
"--device=NAME command-line options:"
<< devices << std::endl
<< "or set --devices_count=COUNT.";
error = ss_error.str();
return;
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions test/conformance/testing/include/uur/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct PlatformEnvironment : ::testing::Environment {

struct PlatformOptions {
std::string platform_name;
unsigned long platforms_count;
};

PlatformEnvironment(int argc, char **argv);
Expand All @@ -36,17 +37,26 @@ struct PlatformEnvironment : ::testing::Environment {

struct DevicesEnvironment : PlatformEnvironment {

struct DeviceOptions {
std::string device_name;
unsigned long devices_count;
};

DevicesEnvironment(int argc, char **argv);
virtual ~DevicesEnvironment() override = default;

virtual void SetUp() override;
virtual void TearDown() override;

DeviceOptions parseDeviceOptions(int argc, char **argv);

inline const std::vector<ur_device_handle_t> &GetDevices() const {
return devices;
}

DeviceOptions device_options;
std::vector<ur_device_handle_t> devices;
ur_device_handle_t device = nullptr;
static DevicesEnvironment *instance;
};

Expand Down

0 comments on commit 36d6fde

Please sign in to comment.