Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeviceSanitizer] Handle the case of urMemGetNativeHandle getting a nullptr Device #1969

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions source/loader/layers/sanitizer/asan_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
return UR_RESULT_SUCCESS;
}

// Device may be null, we follow the L0 adapter's practice to use the first
// device
if (!Device) {
auto Devices = GetDevices(Context);
assert(Devices.size() > 0 && "Devices should not be empty");
Device = Devices[0];
}
assert((void *)Device != nullptr && "Device cannot be nullptr");

std::scoped_lock<ur_shared_mutex> Guard(Mutex);
auto &Allocation = Allocations[Device];
ur_result_t URes = UR_RESULT_SUCCESS;
Expand Down
2 changes: 1 addition & 1 deletion source/loader/layers/sanitizer/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ ur_result_t SanitizerInterceptor::updateShadowMemory(
ur_result_t
SanitizerInterceptor::registerDeviceGlobals(ur_context_handle_t Context,
ur_program_handle_t Program) {
std::vector<ur_device_handle_t> Devices = GetProgramDevices(Program);
std::vector<ur_device_handle_t> Devices = GetDevices(Program);

auto ContextInfo = getContextInfo(Context);

Expand Down
32 changes: 25 additions & 7 deletions source/loader/layers/sanitizer/ur_sanitizer_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,22 @@ ur_device_handle_t GetDevice(ur_queue_handle_t Queue) {
return Device;
}

std::vector<ur_device_handle_t> GetDevices(ur_context_handle_t Context) {
yingcong-wu marked this conversation as resolved.
Show resolved Hide resolved
std::vector<ur_device_handle_t> Devices{};
uint32_t DeviceNum = 0;
[[maybe_unused]] ur_result_t Result;
Result = getContext()->urDdiTable.Context.pfnGetInfo(
Context, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(uint32_t), &DeviceNum,
nullptr);
assert(Result == UR_RESULT_SUCCESS && "getDevices(Context) failed");
Devices.resize(DeviceNum);
Result = getContext()->urDdiTable.Context.pfnGetInfo(
Context, UR_CONTEXT_INFO_DEVICES,
sizeof(ur_device_handle_t) * DeviceNum, Devices.data(), nullptr);
assert(Result == UR_RESULT_SUCCESS && "getDevices(Context) failed");
return Devices;
}

ur_program_handle_t GetProgram(ur_kernel_handle_t Kernel) {
ur_program_handle_t Program{};
[[maybe_unused]] auto Result = getContext()->urDdiTable.Kernel.pfnGetInfo(
Expand Down Expand Up @@ -169,18 +185,20 @@ bool GetDeviceUSMCapability(ur_device_handle_t Device,
return (bool)Flag;
}

std::vector<ur_device_handle_t> GetProgramDevices(ur_program_handle_t Program) {
size_t PropSize;
std::vector<ur_device_handle_t> GetDevices(ur_program_handle_t Program) {
uint32_t DeviceNum = 0;
[[maybe_unused]] ur_result_t Result =
getContext()->urDdiTable.Program.pfnGetInfo(
Program, UR_PROGRAM_INFO_DEVICES, 0, nullptr, &PropSize);
assert(Result == UR_RESULT_SUCCESS);
Program, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(DeviceNum), &DeviceNum,
nullptr);
assert(Result == UR_RESULT_SUCCESS && "getDevices(Program) failed");

std::vector<ur_device_handle_t> Devices;
Devices.resize(PropSize / sizeof(ur_device_handle_t));
Devices.resize(DeviceNum);
Result = getContext()->urDdiTable.Program.pfnGetInfo(
Program, UR_PROGRAM_INFO_DEVICES, PropSize, Devices.data(), nullptr);
assert(Result == UR_RESULT_SUCCESS);
Program, UR_PROGRAM_INFO_DEVICES,
DeviceNum * sizeof(ur_device_handle_t), Devices.data(), nullptr);
assert(Result == UR_RESULT_SUCCESS && "getDevices(Program) failed");

return Devices;
}
Expand Down
3 changes: 2 additions & 1 deletion source/loader/layers/sanitizer/ur_sanitizer_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ ur_context_handle_t GetContext(ur_queue_handle_t Queue);
ur_context_handle_t GetContext(ur_program_handle_t Program);
ur_context_handle_t GetContext(ur_kernel_handle_t Kernel);
ur_device_handle_t GetDevice(ur_queue_handle_t Queue);
std::vector<ur_device_handle_t> GetDevices(ur_context_handle_t Context);
yingcong-wu marked this conversation as resolved.
Show resolved Hide resolved
std::vector<ur_device_handle_t> GetDevices(ur_program_handle_t Program);
DeviceType GetDeviceType(ur_context_handle_t Context,
ur_device_handle_t Device);
ur_device_handle_t GetParentDevice(ur_device_handle_t Device);
Expand All @@ -42,7 +44,6 @@ bool GetDeviceUSMCapability(ur_device_handle_t Device,
std::string GetKernelName(ur_kernel_handle_t Kernel);
size_t GetDeviceLocalMemorySize(ur_device_handle_t Device);
ur_program_handle_t GetProgram(ur_kernel_handle_t Kernel);
std::vector<ur_device_handle_t> GetProgramDevices(ur_program_handle_t Program);
ur_device_handle_t GetUSMAllocDevice(ur_context_handle_t Context,
const void *MemPtr);
uint32_t GetKernelNumArgs(ur_kernel_handle_t Kernel);
Expand Down
Loading