diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 503bae85..8853a9ff 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -100,6 +100,8 @@ namespace loader %endif if( ${X}_RESULT_SUCCESS != result ) break; + drv.driverInuse = true; + try { for( uint32_t i = 0; i < library_driver_handle_count; ++i ) { diff --git a/scripts/templates/nullddi.cpp.mako b/scripts/templates/nullddi.cpp.mako index 8b32a46b..3ba679c7 100644 --- a/scripts/templates/nullddi.cpp.mako +++ b/scripts/templates/nullddi.cpp.mako @@ -17,6 +17,7 @@ from templates import helper as th * */ #include "${x}_null.h" +#include namespace driver { @@ -46,6 +47,33 @@ namespace driver else { // generic implementation + %if re.match("Init", obj['name']): + %if re.match("InitDrivers", obj['name']): + auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" ); + if (std::strcmp(driver_type.c_str(), "GPU") == 0) { + if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_GPU)) { + return ${X}_RESULT_ERROR_UNINITIALIZED; + } + } + if (std::strcmp(driver_type.c_str(), "NPU") == 0) { + if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_NPU)) { + return ${X}_RESULT_ERROR_UNINITIALIZED; + } + } + %else: + auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" ); + if (std::strcmp(driver_type.c_str(), "GPU") == 0) { + if (!(flags & ZE_INIT_FLAG_GPU_ONLY)) { + return ${X}_RESULT_ERROR_UNINITIALIZED; + } + } + if (std::strcmp(driver_type.c_str(), "NPU") == 0) { + if (!(flags & ZE_INIT_FLAG_VPU_ONLY)) { + return ${X}_RESULT_ERROR_UNINITIALIZED; + } + } + %endif + %endif %for item in th.get_loader_epilogue(n, tags, obj, meta): %if 'range' in item: for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i ) diff --git a/source/drivers/null/ze_nullddi.cpp b/source/drivers/null/ze_nullddi.cpp index 130c7454..54ab0a70 100644 --- a/source/drivers/null/ze_nullddi.cpp +++ b/source/drivers/null/ze_nullddi.cpp @@ -8,6 +8,7 @@ * */ #include "ze_null.h" +#include namespace driver { @@ -30,6 +31,17 @@ namespace driver else { // generic implementation + auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" ); + if (std::strcmp(driver_type.c_str(), "GPU") == 0) { + if (!(flags & ZE_INIT_FLAG_GPU_ONLY)) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + } + if (std::strcmp(driver_type.c_str(), "NPU") == 0) { + if (!(flags & ZE_INIT_FLAG_VPU_ONLY)) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + } } return result; @@ -95,6 +107,17 @@ namespace driver else { // generic implementation + auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" ); + if (std::strcmp(driver_type.c_str(), "GPU") == 0) { + if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_GPU)) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + } + if (std::strcmp(driver_type.c_str(), "NPU") == 0) { + if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_NPU)) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + } for( size_t i = 0; ( nullptr != phDrivers ) && ( i < *pCount ); ++i ) phDrivers[ i ] = reinterpret_cast( context.get() ); diff --git a/source/drivers/null/zes_nullddi.cpp b/source/drivers/null/zes_nullddi.cpp index 5e8337c8..3717dd73 100644 --- a/source/drivers/null/zes_nullddi.cpp +++ b/source/drivers/null/zes_nullddi.cpp @@ -8,6 +8,7 @@ * */ #include "ze_null.h" +#include namespace driver { @@ -30,6 +31,17 @@ namespace driver else { // generic implementation + auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" ); + if (std::strcmp(driver_type.c_str(), "GPU") == 0) { + if (!(flags & ZE_INIT_FLAG_GPU_ONLY)) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + } + if (std::strcmp(driver_type.c_str(), "NPU") == 0) { + if (!(flags & ZE_INIT_FLAG_VPU_ONLY)) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + } } return result; diff --git a/source/drivers/null/zet_nullddi.cpp b/source/drivers/null/zet_nullddi.cpp index ef4757bb..b85ecab9 100644 --- a/source/drivers/null/zet_nullddi.cpp +++ b/source/drivers/null/zet_nullddi.cpp @@ -8,6 +8,7 @@ * */ #include "ze_null.h" +#include namespace driver { diff --git a/source/loader/ze_ldrddi.cpp b/source/loader/ze_ldrddi.cpp index 27832f13..7d1d5e65 100644 --- a/source/loader/ze_ldrddi.cpp +++ b/source/loader/ze_ldrddi.cpp @@ -82,6 +82,8 @@ namespace loader result = drv.dditable.ze.Driver.pfnGet( &library_driver_handle_count, &phDrivers[ total_driver_handle_count ] ); if( ZE_RESULT_SUCCESS != result ) break; + drv.driverInuse = true; + try { for( uint32_t i = 0; i < library_driver_handle_count; ++i ) { @@ -154,6 +156,8 @@ namespace loader result = drv.dditable.ze.Global.pfnInitDrivers( &library_driver_handle_count, &phDrivers[ total_driver_handle_count ], desc ); if( ZE_RESULT_SUCCESS != result ) break; + drv.driverInuse = true; + try { for( uint32_t i = 0; i < library_driver_handle_count; ++i ) { diff --git a/source/loader/ze_loader.cpp b/source/loader/ze_loader.cpp index eafebe2d..8b957cce 100644 --- a/source/loader/ze_loader.cpp +++ b/source/loader/ze_loader.cpp @@ -171,16 +171,20 @@ namespace loader std::string errorMessage = "Check Drivers Failed on " + it->name + " , driver will be removed. " + initName + " failed with "; debug_trace_message(errorMessage, loader::to_string(result)); } - it = drivers->erase(it); - // If the number of drivers is now ==1, then we need to reinit the ddi tables to pass through. - // If ZE_ENABLE_LOADER_INTERCEPT is set to 1, then even if drivers were removed, don't reinit the ddi tables. - if (drivers->size() == 1 && !loader::context->forceIntercept) { - *requireDdiReinit = true; + // If the driver has already been init and handles are to be read, then this driver cannot be removed from the list. + if (!it->driverInuse) { + it = drivers->erase(it); + // If the number of drivers is now ==1, then we need to reinit the ddi tables to pass through. + // If ZE_ENABLE_LOADER_INTERCEPT is set to 1, then even if drivers were removed, don't reinit the ddi tables. + if (drivers->size() == 1 && !loader::context->forceIntercept) { + *requireDdiReinit = true; + } } if(return_first_driver_result) return result; - } - else { + } else { + // If this is a single driver system, then the first success for this driver needs to be set. + it->driverInuse = true; it++; } } diff --git a/source/loader/ze_loader_internal.h b/source/loader/ze_loader_internal.h index 445c26af..4ecdd33e 100644 --- a/source/loader/ze_loader_internal.h +++ b/source/loader/ze_loader_internal.h @@ -33,6 +33,7 @@ namespace loader ze_result_t initStatus = ZE_RESULT_SUCCESS; dditable_t dditable = {}; std::string name; + bool driverInuse = false; }; using driver_vector_t = std::vector< driver_t >; diff --git a/source/loader/zes_ldrddi.cpp b/source/loader/zes_ldrddi.cpp index d0b1ff85..bbc8862f 100644 --- a/source/loader/zes_ldrddi.cpp +++ b/source/loader/zes_ldrddi.cpp @@ -83,6 +83,8 @@ namespace loader result = drv.dditable.zes.Driver.pfnGet( &library_driver_handle_count, &phDrivers[ total_driver_handle_count ] ); if( ZE_RESULT_SUCCESS != result ) break; + drv.driverInuse = true; + try { for( uint32_t i = 0; i < library_driver_handle_count; ++i ) { diff --git a/test/loader_api.cpp b/test/loader_api.cpp index 4fdf8e19..fc5db002 100644 --- a/test/loader_api.cpp +++ b/test/loader_api.cpp @@ -11,6 +11,12 @@ #include "loader/ze_loader.h" #include "ze_api.h" +#if defined(_WIN32) + #define putenv_safe _putenv +#else + #define putenv_safe putenv +#endif + namespace { TEST( @@ -42,4 +48,176 @@ TEST( } } +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithGPUTypeThenExpectPassWithGPUorAllOnly) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=GPU" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + desc.flags = UINT32_MAX; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU | ZE_INIT_DRIVER_TYPE_FLAG_NPU; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithNPUTypeThenExpectPassWithNPUorAllOnly) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=NPU" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + desc.flags = UINT32_MAX; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU | ZE_INIT_DRIVER_TYPE_FLAG_NPU; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithAnyTypeWithNullDriverAcceptingAllThenExpectatLeast1Driver) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=ALL" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + desc.flags = UINT32_MAX; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU | ZE_INIT_DRIVER_TYPE_FLAG_NPU; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitThenBothCallsSucceedWithAllTypes) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=ALL" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0)); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitThenBothCallsSucceedWithGPUTypes) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=GPU" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_GPU_ONLY)); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitThenBothCallsSucceedWithNPUTypes) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=NPU" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_VPU_ONLY)); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithAllTypes) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=ALL" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0)); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithGPUTypes) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=GPU" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_GPU_ONLY)); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithNPUTypes) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=NPU" ) ); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_VPU_ONLY)); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenOnlyOneSucceedsforGPUTypes) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=GPU" ) ); + EXPECT_EQ(ZE_RESULT_ERROR_UNINITIALIZED, zeInit(ZE_INIT_FLAG_VPU_ONLY)); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); +} + +TEST( + LoaderAPI, + GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenOnlyOneSucceedsforNPUTypes) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_NULL_DRIVER_TYPE=NPU" ) ); + EXPECT_EQ(ZE_RESULT_ERROR_UNINITIALIZED, zeInit(ZE_INIT_FLAG_GPU_ONLY)); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_GT(pCount, 0); +} + } // namespace