diff --git a/source/loader/layers/sanitizer/asan_buffer.cpp b/source/loader/layers/sanitizer/asan_buffer.cpp index 896fee986e..9316d68bf4 100644 --- a/source/loader/layers/sanitizer/asan_buffer.cpp +++ b/source/loader/layers/sanitizer/asan_buffer.cpp @@ -75,6 +75,15 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { return UR_RESULT_SUCCESS; } + // Device may be null, we follow the L0 adapter's practice to use the first + // device + if (!Device) { + auto Devices = GetDevices(Context); + assert(Devices.size() > 0 && "Devices should not be empty"); + Device = Devices[0]; + } + assert((void *)Device != nullptr && "Device cannot be nullptr"); + std::scoped_lock Guard(Mutex); auto &Allocation = Allocations[Device]; ur_result_t URes = UR_RESULT_SUCCESS; diff --git a/source/loader/layers/sanitizer/asan_interceptor.cpp b/source/loader/layers/sanitizer/asan_interceptor.cpp index c4b8986b58..3abe7e79b4 100644 --- a/source/loader/layers/sanitizer/asan_interceptor.cpp +++ b/source/loader/layers/sanitizer/asan_interceptor.cpp @@ -552,7 +552,7 @@ ur_result_t SanitizerInterceptor::updateShadowMemory( ur_result_t SanitizerInterceptor::registerDeviceGlobals(ur_context_handle_t Context, ur_program_handle_t Program) { - std::vector Devices = GetProgramDevices(Program); + std::vector Devices = GetDevices(Program); auto ContextInfo = getContextInfo(Context); diff --git a/source/loader/layers/sanitizer/ur_sanitizer_utils.cpp b/source/loader/layers/sanitizer/ur_sanitizer_utils.cpp index feaff8757a..50cbdaf338 100644 --- a/source/loader/layers/sanitizer/ur_sanitizer_utils.cpp +++ b/source/loader/layers/sanitizer/ur_sanitizer_utils.cpp @@ -72,6 +72,22 @@ ur_device_handle_t GetDevice(ur_queue_handle_t Queue) { return Device; } +std::vector GetDevices(ur_context_handle_t Context) { + std::vector Devices{}; + uint32_t DeviceNum = 0; + [[maybe_unused]] ur_result_t Result; + Result = getContext()->urDdiTable.Context.pfnGetInfo( + Context, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(uint32_t), &DeviceNum, + nullptr); + assert(Result == UR_RESULT_SUCCESS && "getDevices(Context) failed"); + Devices.resize(DeviceNum); + Result = getContext()->urDdiTable.Context.pfnGetInfo( + Context, UR_CONTEXT_INFO_DEVICES, + sizeof(ur_device_handle_t) * DeviceNum, Devices.data(), nullptr); + assert(Result == UR_RESULT_SUCCESS && "getDevices(Context) failed"); + return Devices; +} + ur_program_handle_t GetProgram(ur_kernel_handle_t Kernel) { ur_program_handle_t Program{}; [[maybe_unused]] auto Result = getContext()->urDdiTable.Kernel.pfnGetInfo( @@ -169,18 +185,20 @@ bool GetDeviceUSMCapability(ur_device_handle_t Device, return (bool)Flag; } -std::vector GetProgramDevices(ur_program_handle_t Program) { - size_t PropSize; +std::vector GetDevices(ur_program_handle_t Program) { + uint32_t DeviceNum = 0; [[maybe_unused]] ur_result_t Result = getContext()->urDdiTable.Program.pfnGetInfo( - Program, UR_PROGRAM_INFO_DEVICES, 0, nullptr, &PropSize); - assert(Result == UR_RESULT_SUCCESS); + Program, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(DeviceNum), &DeviceNum, + nullptr); + assert(Result == UR_RESULT_SUCCESS && "getDevices(Program) failed"); std::vector Devices; - Devices.resize(PropSize / sizeof(ur_device_handle_t)); + Devices.resize(DeviceNum); Result = getContext()->urDdiTable.Program.pfnGetInfo( - Program, UR_PROGRAM_INFO_DEVICES, PropSize, Devices.data(), nullptr); - assert(Result == UR_RESULT_SUCCESS); + Program, UR_PROGRAM_INFO_DEVICES, + DeviceNum * sizeof(ur_device_handle_t), Devices.data(), nullptr); + assert(Result == UR_RESULT_SUCCESS && "getDevices(Program) failed"); return Devices; } diff --git a/source/loader/layers/sanitizer/ur_sanitizer_utils.hpp b/source/loader/layers/sanitizer/ur_sanitizer_utils.hpp index 44ddf46922..479ecd635b 100644 --- a/source/loader/layers/sanitizer/ur_sanitizer_utils.hpp +++ b/source/loader/layers/sanitizer/ur_sanitizer_utils.hpp @@ -34,6 +34,8 @@ ur_context_handle_t GetContext(ur_queue_handle_t Queue); ur_context_handle_t GetContext(ur_program_handle_t Program); ur_context_handle_t GetContext(ur_kernel_handle_t Kernel); ur_device_handle_t GetDevice(ur_queue_handle_t Queue); +std::vector GetDevices(ur_context_handle_t Context); +std::vector GetDevices(ur_program_handle_t Program); DeviceType GetDeviceType(ur_context_handle_t Context, ur_device_handle_t Device); ur_device_handle_t GetParentDevice(ur_device_handle_t Device); @@ -42,7 +44,6 @@ bool GetDeviceUSMCapability(ur_device_handle_t Device, std::string GetKernelName(ur_kernel_handle_t Kernel); size_t GetDeviceLocalMemorySize(ur_device_handle_t Device); ur_program_handle_t GetProgram(ur_kernel_handle_t Kernel); -std::vector GetProgramDevices(ur_program_handle_t Program); ur_device_handle_t GetUSMAllocDevice(ur_context_handle_t Context, const void *MemPtr); uint32_t GetKernelNumArgs(ur_kernel_handle_t Kernel);