Skip to content

Commit

Permalink
Merge pull request #1667 from nrspruit/fix_multi_device_event_cache
Browse files Browse the repository at this point in the history
[UR] Fix Multi Device Event Cache for shared Root Device
  • Loading branch information
omarahmed1111 authored May 29, 2024
2 parents e18c691 + 0f2d1f4 commit c0c607c
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 4 deletions.
21 changes: 17 additions & 4 deletions source/adapters/level_zero/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,20 @@ ur_result_t _ur_ze_event_list_t::createAndRetainUrZeEventList(

std::shared_lock<ur_shared_mutex> Lock(EventList[I]->Mutex);

if (Queue && Queue->Device != CurQueueDevice &&
ur_device_handle_t QueueRootDevice;
ur_device_handle_t CurrentQueueRootDevice;
if (Queue) {
QueueRootDevice = Queue->Device;
CurrentQueueRootDevice = CurQueueDevice;
if (Queue->Device->isSubDevice()) {
QueueRootDevice = Queue->Device->RootDevice;
}
if (CurQueueDevice->isSubDevice()) {
CurrentQueueRootDevice = CurQueueDevice->RootDevice;
}
}

if (Queue && QueueRootDevice != CurrentQueueRootDevice &&
!EventList[I]->IsMultiDevice) {
ze_event_handle_t MultiDeviceZeEvent = nullptr;
ur_event_handle_t MultiDeviceEvent;
Expand All @@ -1514,10 +1527,10 @@ ur_result_t _ur_ze_event_list_t::createAndRetainUrZeEventList(
const auto &ZeCommandList = CommandList->first;
EventList[I]->RefCount.increment();

zeCommandListAppendWaitOnEvents(ZeCommandList, 1u,
&EventList[I]->ZeEvent);
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
(ZeCommandList, 1u, &EventList[I]->ZeEvent));
if (!MultiDeviceEvent->CounterBasedEventsEnabled)
zeEventHostSignal(MultiDeviceZeEvent);
ZE2UR_CALL(zeEventHostSignal, (MultiDeviceZeEvent));

UR_CALL(Queue->executeCommandList(CommandList, /* IsBlocking */ false,
/* OkToBatchCommand */ true));
Expand Down
11 changes: 11 additions & 0 deletions test/adapters/level_zero/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,15 @@ if(NOT WIN32)
)

target_link_libraries(test-adapter-level_zero_ze_calls PRIVATE zeCallMap)

add_adapter_test(level_zero_multi_queue
FIXTURE DEVICES
SOURCES
multi_device_event_cache_tests.cpp
ENVIRONMENT
"UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_level_zero>\""
"UR_L0_LEAKS_DEBUG=1"
)

target_link_libraries(test-adapter-level_zero_multi_queue PRIVATE zeCallMap)
endif()
107 changes: 107 additions & 0 deletions test/adapters/level_zero/multi_device_event_cache_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (C) 2024 Intel Corporation
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "ur_print.hpp"
#include "uur/fixtures.h"
#include "uur/raii.h"

#include <map>
#include <string>

extern std::map<std::string, int> *ZeCallCount;

using urMultiQueueMultiDeviceEventCacheTest = uur::urAllDevicesTest;
TEST_F(urMultiQueueMultiDeviceEventCacheTest,
GivenMultiSubDeviceWithQueuePerSubDeviceThenEventIsSharedBetweenQueues) {
uint32_t max_sub_devices = 0;
ASSERT_SUCCESS(
uur::GetDevicePartitionMaxSubDevices(devices[0], max_sub_devices));
if (max_sub_devices < 2) {
GTEST_SKIP();
}
ur_device_partition_property_t prop;
prop.type = UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN;
prop.value.affinity_domain =
UR_DEVICE_AFFINITY_DOMAIN_FLAG_NEXT_PARTITIONABLE;

ur_device_partition_properties_t properties{
UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES,
nullptr,
&prop,
1,
};
uint32_t numSubDevices = 0;
ASSERT_SUCCESS(
urDevicePartition(devices[0], &properties, 0, nullptr, &numSubDevices));
std::vector<ur_device_handle_t> sub_devices;
sub_devices.reserve(numSubDevices);
ASSERT_SUCCESS(urDevicePartition(devices[0], &properties, numSubDevices,
sub_devices.data(), nullptr));
uur::raii::Context context1 = nullptr;
ASSERT_SUCCESS(
urContextCreate(1, &sub_devices[0], nullptr, context1.ptr()));
ASSERT_NE(nullptr, context1);
uur::raii::Context context2 = nullptr;
ASSERT_SUCCESS(
urContextCreate(1, &sub_devices[1], nullptr, context2.ptr()));
ASSERT_NE(nullptr, context2);
ur_queue_handle_t queue1 = nullptr;
ASSERT_SUCCESS(urQueueCreate(context1, sub_devices[0], 0, &queue1));
ur_queue_handle_t queue2 = nullptr;
ASSERT_SUCCESS(urQueueCreate(context2, sub_devices[1], 0, &queue2));
uur::raii::Event event = nullptr;
uur::raii::Event eventWait = nullptr;
uur::raii::Event eventWaitDummy = nullptr;
(*ZeCallCount)["zeCommandListAppendWaitOnEvents"] = 0;
EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context2, nullptr,
eventWait.ptr()));
EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context1, nullptr,
eventWaitDummy.ptr()));
EXPECT_SUCCESS(
urEnqueueEventsWait(queue1, 1, eventWaitDummy.ptr(), eventWait.ptr()));
EXPECT_SUCCESS(
urEnqueueEventsWait(queue2, 1, eventWait.ptr(), event.ptr()));
EXPECT_EQ((*ZeCallCount)["zeCommandListAppendWaitOnEvents"], 2);
ASSERT_SUCCESS(urEventRelease(eventWaitDummy.get()));
ASSERT_SUCCESS(urEventRelease(eventWait.get()));
ASSERT_SUCCESS(urEventRelease(event.get()));
ASSERT_SUCCESS(urQueueRelease(queue2));
ASSERT_SUCCESS(urQueueRelease(queue1));
}

TEST_F(urMultiQueueMultiDeviceEventCacheTest,
GivenMultiDeviceWithQueuePerDeviceThenMultiDeviceEventIsCreated) {
if (devices.size() < 2) {
GTEST_SKIP();
}
uur::raii::Context context1 = nullptr;
ASSERT_SUCCESS(urContextCreate(1, &devices[0], nullptr, context1.ptr()));
ASSERT_NE(nullptr, context1);
uur::raii::Context context2 = nullptr;
ASSERT_SUCCESS(urContextCreate(1, &devices[1], nullptr, context2.ptr()));
ASSERT_NE(nullptr, context2);
ur_queue_handle_t queue1 = nullptr;
ASSERT_SUCCESS(urQueueCreate(context1, devices[0], 0, &queue1));
ur_queue_handle_t queue2 = nullptr;
ASSERT_SUCCESS(urQueueCreate(context2, devices[1], 0, &queue2));
uur::raii::Event event = nullptr;
uur::raii::Event eventWait = nullptr;
uur::raii::Event eventWaitDummy = nullptr;
(*ZeCallCount)["zeCommandListAppendWaitOnEvents"] = 0;
EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context2, nullptr,
eventWait.ptr()));
EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context1, nullptr,
eventWaitDummy.ptr()));
EXPECT_SUCCESS(
urEnqueueEventsWait(queue1, 1, eventWaitDummy.ptr(), eventWait.ptr()));
EXPECT_SUCCESS(
urEnqueueEventsWait(queue2, 1, eventWait.ptr(), event.ptr()));
EXPECT_EQ((*ZeCallCount)["zeCommandListAppendWaitOnEvents"], 3);
ASSERT_SUCCESS(urEventRelease(eventWaitDummy.get()));
ASSERT_SUCCESS(urEventRelease(eventWait.get()));
ASSERT_SUCCESS(urEventRelease(event.get()));
ASSERT_SUCCESS(urQueueRelease(queue2));
ASSERT_SUCCESS(urQueueRelease(queue1));
}

0 comments on commit c0c607c

Please sign in to comment.