Skip to content

Commit

Permalink
Add driver init check and init tests
Browse files Browse the repository at this point in the history
Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
  • Loading branch information
nrspruit committed Sep 12, 2024
1 parent 528faad commit 77710c8
Show file tree
Hide file tree
Showing 10 changed files with 262 additions and 7 deletions.
2 changes: 2 additions & 0 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ namespace loader
%endif
if( ${X}_RESULT_SUCCESS != result ) break;

drv.driverInuse = true;

try
{
for( uint32_t i = 0; i < library_driver_handle_count; ++i ) {
Expand Down
28 changes: 28 additions & 0 deletions scripts/templates/nullddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ from templates import helper as th
*
*/
#include "${x}_null.h"
#include <cstring>

namespace driver
{
Expand Down Expand Up @@ -46,6 +47,33 @@ namespace driver
else
{
// generic implementation
%if re.match("Init", obj['name']):
%if re.match("InitDrivers", obj['name']):
auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" );
if (std::strcmp(driver_type.c_str(), "GPU") == 0) {
if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_GPU)) {
return ${X}_RESULT_ERROR_UNINITIALIZED;
}
}
if (std::strcmp(driver_type.c_str(), "NPU") == 0) {
if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_NPU)) {
return ${X}_RESULT_ERROR_UNINITIALIZED;
}
}
%else:
auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" );
if (std::strcmp(driver_type.c_str(), "GPU") == 0) {
if (!(flags & ZE_INIT_FLAG_GPU_ONLY)) {
return ${X}_RESULT_ERROR_UNINITIALIZED;
}
}
if (std::strcmp(driver_type.c_str(), "NPU") == 0) {
if (!(flags & ZE_INIT_FLAG_VPU_ONLY)) {
return ${X}_RESULT_ERROR_UNINITIALIZED;
}
}
%endif
%endif
%for item in th.get_loader_epilogue(n, tags, obj, meta):
%if 'range' in item:
for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i )
Expand Down
23 changes: 23 additions & 0 deletions source/drivers/null/ze_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*
*/
#include "ze_null.h"
#include <cstring>

namespace driver
{
Expand All @@ -30,6 +31,17 @@ namespace driver
else
{
// generic implementation
auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" );
if (std::strcmp(driver_type.c_str(), "GPU") == 0) {
if (!(flags & ZE_INIT_FLAG_GPU_ONLY)) {
return ZE_RESULT_ERROR_UNINITIALIZED;
}
}
if (std::strcmp(driver_type.c_str(), "NPU") == 0) {
if (!(flags & ZE_INIT_FLAG_VPU_ONLY)) {
return ZE_RESULT_ERROR_UNINITIALIZED;
}
}
}

return result;
Expand Down Expand Up @@ -95,6 +107,17 @@ namespace driver
else
{
// generic implementation
auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" );
if (std::strcmp(driver_type.c_str(), "GPU") == 0) {
if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_GPU)) {
return ZE_RESULT_ERROR_UNINITIALIZED;
}
}
if (std::strcmp(driver_type.c_str(), "NPU") == 0) {
if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_NPU)) {
return ZE_RESULT_ERROR_UNINITIALIZED;
}
}
for( size_t i = 0; ( nullptr != phDrivers ) && ( i < *pCount ); ++i )
phDrivers[ i ] = reinterpret_cast<ze_driver_handle_t>( context.get() );

Expand Down
12 changes: 12 additions & 0 deletions source/drivers/null/zes_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*
*/
#include "ze_null.h"
#include <cstring>

namespace driver
{
Expand All @@ -30,6 +31,17 @@ namespace driver
else
{
// generic implementation
auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" );
if (std::strcmp(driver_type.c_str(), "GPU") == 0) {
if (!(flags & ZE_INIT_FLAG_GPU_ONLY)) {
return ZE_RESULT_ERROR_UNINITIALIZED;
}
}
if (std::strcmp(driver_type.c_str(), "NPU") == 0) {
if (!(flags & ZE_INIT_FLAG_VPU_ONLY)) {
return ZE_RESULT_ERROR_UNINITIALIZED;
}
}
}

return result;
Expand Down
1 change: 1 addition & 0 deletions source/drivers/null/zet_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*
*/
#include "ze_null.h"
#include <cstring>

namespace driver
{
Expand Down
4 changes: 4 additions & 0 deletions source/loader/ze_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ namespace loader
result = drv.dditable.ze.Driver.pfnGet( &library_driver_handle_count, &phDrivers[ total_driver_handle_count ] );
if( ZE_RESULT_SUCCESS != result ) break;

drv.driverInuse = true;

try
{
for( uint32_t i = 0; i < library_driver_handle_count; ++i ) {
Expand Down Expand Up @@ -154,6 +156,8 @@ namespace loader
result = drv.dditable.ze.Global.pfnInitDrivers( &library_driver_handle_count, &phDrivers[ total_driver_handle_count ], desc );
if( ZE_RESULT_SUCCESS != result ) break;

drv.driverInuse = true;

try
{
for( uint32_t i = 0; i < library_driver_handle_count; ++i ) {
Expand Down
18 changes: 11 additions & 7 deletions source/loader/ze_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,20 @@ namespace loader
std::string errorMessage = "Check Drivers Failed on " + it->name + " , driver will be removed. " + initName + " failed with ";
debug_trace_message(errorMessage, loader::to_string(result));
}
it = drivers->erase(it);
// If the number of drivers is now ==1, then we need to reinit the ddi tables to pass through.
// If ZE_ENABLE_LOADER_INTERCEPT is set to 1, then even if drivers were removed, don't reinit the ddi tables.
if (drivers->size() == 1 && !loader::context->forceIntercept) {
*requireDdiReinit = true;
// If the driver has already been init and handles are to be read, then this driver cannot be removed from the list.
if (!it->driverInuse) {
it = drivers->erase(it);
// If the number of drivers is now ==1, then we need to reinit the ddi tables to pass through.
// If ZE_ENABLE_LOADER_INTERCEPT is set to 1, then even if drivers were removed, don't reinit the ddi tables.
if (drivers->size() == 1 && !loader::context->forceIntercept) {
*requireDdiReinit = true;
}
}
if(return_first_driver_result)
return result;
}
else {
} else {
// If this is a single driver system, then the first success for this driver needs to be set.
it->driverInuse = true;
it++;
}
}
Expand Down
1 change: 1 addition & 0 deletions source/loader/ze_loader_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace loader
ze_result_t initStatus = ZE_RESULT_SUCCESS;
dditable_t dditable = {};
std::string name;
bool driverInuse = false;
};

using driver_vector_t = std::vector< driver_t >;
Expand Down
2 changes: 2 additions & 0 deletions source/loader/zes_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ namespace loader
result = drv.dditable.zes.Driver.pfnGet( &library_driver_handle_count, &phDrivers[ total_driver_handle_count ] );
if( ZE_RESULT_SUCCESS != result ) break;

drv.driverInuse = true;

try
{
for( uint32_t i = 0; i < library_driver_handle_count; ++i ) {
Expand Down
178 changes: 178 additions & 0 deletions test/loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
#include "loader/ze_loader.h"
#include "ze_api.h"

#if defined(_WIN32)
#define putenv_safe _putenv
#else
#define putenv_safe putenv
#endif

namespace {

TEST(
Expand Down Expand Up @@ -42,4 +48,176 @@ TEST(
}
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithGPUTypeThenExpectPassWithGPUorAllOnly) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=GPU" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
desc.flags = UINT32_MAX;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU | ZE_INIT_DRIVER_TYPE_FLAG_NPU;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithNPUTypeThenExpectPassWithNPUorAllOnly) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=NPU" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
desc.flags = UINT32_MAX;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU | ZE_INIT_DRIVER_TYPE_FLAG_NPU;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithAnyTypeWithNullDriverAcceptingAllThenExpectatLeast1Driver) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=ALL" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
desc.flags = UINT32_MAX;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU | ZE_INIT_DRIVER_TYPE_FLAG_NPU;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitThenBothCallsSucceedWithAllTypes) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = UINT32_MAX;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=ALL" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0));
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitThenBothCallsSucceedWithGPUTypes) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = UINT32_MAX;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=GPU" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_GPU_ONLY));
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitThenBothCallsSucceedWithNPUTypes) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = UINT32_MAX;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=NPU" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_VPU_ONLY));
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithAllTypes) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = UINT32_MAX;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=ALL" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithGPUTypes) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = UINT32_MAX;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=GPU" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_GPU_ONLY));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithNPUTypes) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = UINT32_MAX;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=NPU" ) );
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_VPU_ONLY));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenOnlyOneSucceedsforGPUTypes) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = UINT32_MAX;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=GPU" ) );
EXPECT_EQ(ZE_RESULT_ERROR_UNINITIALIZED, zeInit(ZE_INIT_FLAG_VPU_ONLY));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
}

TEST(
LoaderAPI,
GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenOnlyOneSucceedsforNPUTypes) {

uint32_t pCount = 0;
ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
desc.flags = UINT32_MAX;
desc.pNext = nullptr;
putenv_safe( const_cast<char *>( "ZEL_TEST_NULL_DRIVER_TYPE=NPU" ) );
EXPECT_EQ(ZE_RESULT_ERROR_UNINITIALIZED, zeInit(ZE_INIT_FLAG_GPU_ONLY));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc));
EXPECT_GT(pCount, 0);
}

} // namespace

0 comments on commit 77710c8

Please sign in to comment.