diff --git a/loader/icd.c b/loader/icd.c index 1cb950b3..e50b27a2 100644 --- a/loader/icd.c +++ b/loader/icd.c @@ -60,6 +60,7 @@ void khrIcdVendorAdd(const char *libraryName) pfn_clIcdGetPlatformIDs p_clIcdGetPlatformIDs = NULL; #if KHR_LOADER_MANAGED_DISPATCH clGetFunctionAddressForPlatformKHR_fn p_clGetFunctionAddressForPlatform = NULL; + clSetPlatformDispatchDataKHR_fn p_clSetPlatformDispatchData = NULL; #endif cl_uint i = 0; cl_uint platformCount = 0; @@ -108,8 +109,9 @@ void khrIcdVendorAdd(const char *libraryName) } #if KHR_LOADER_MANAGED_DISPATCH - // try to get clGetFunctionAddressForPlatformKHR to detect cl_khr_icd2 support + // try to get clGetFunctionAddressForPlatformKHR and clSetPlatformDispatchDataKHR to detect cl_khr_icd2 support p_clGetFunctionAddressForPlatform = (clGetFunctionAddressForPlatformKHR_fn)(size_t)p_clGetExtensionFunctionAddress("clGetFunctionAddressForPlatformKHR"); + p_clSetPlatformDispatchData = (clSetPlatformDispatchDataKHR_fn)(size_t)p_clGetExtensionFunctionAddress("clSetPlatformDispatchDataKHR"); #endif // query the number of platforms available and allocate space to store them @@ -152,6 +154,11 @@ void khrIcdVendorAdd(const char *libraryName) KHR_ICD_TRACE("found icd 2 object, but platform is missing clGetFunctionAddressForPlatformKHR"); continue; } + if (KHR_ICD2_HAS_TAG(platforms[i]) && !p_clSetPlatformDispatchData) + { + KHR_ICD_TRACE("found icd 2 object, but platform is missing clSetPlatformDispatchDataKHR"); + continue; + } #endif // allocate a structure for the vendor @@ -165,10 +172,10 @@ void khrIcdVendorAdd(const char *libraryName) #if KHR_LOADER_MANAGED_DISPATCH // populate cl_khr_icd2 platform's loader managed dispatch tables - if (p_clGetFunctionAddressForPlatform && KHR_ICD2_HAS_TAG(platforms[i])) + if (KHR_ICD2_HAS_TAG(platforms[i])) { khrIcd2PopulateDispatchTable(platforms[i], p_clGetFunctionAddressForPlatform, &vendor->dispData.dispatch); - platforms[i]->dispData = &vendor->dispData; + p_clSetPlatformDispatchData(platforms[i], &vendor->dispData); KHR_ICD_TRACE("found icd 2 platform, using loader managed dispatch\n"); } #endif diff --git a/loader/icd_dispatch.h b/loader/icd_dispatch.h index 36e0d5b9..58107486 100644 --- a/loader/icd_dispatch.h +++ b/loader/icd_dispatch.h @@ -77,6 +77,14 @@ clGetFunctionAddressForPlatformKHR_t( typedef clGetFunctionAddressForPlatformKHR_t * clGetFunctionAddressForPlatformKHR_fn; +typedef int CL_API_CALL +clSetPlatformDispatchDataKHR_t( + cl_platform_id platform, + void *disp_data); + +typedef clSetPlatformDispatchDataKHR_t * +clSetPlatformDispatchDataKHR_fn; + extern void khrIcd2PopulateDispatchTable( cl_platform_id platform, clGetFunctionAddressForPlatformKHR_fn p_clGetFunctionAddressForPlatform,