Skip to content

Commit

Permalink
[SYCL] Protect access to the native handle of a sycl::event (#15179)
Browse files Browse the repository at this point in the history
Fix for #14623

Currently event_impl exposes reference to the underlying UR handle.
As a result this handle can be updated/read at the random moments of
time by different threads causing data race.
This PR removes methods which expose the reference and replace them with
thread-safe getter/setter.
  • Loading branch information
againull committed Aug 28, 2024
1 parent e374c69 commit a689b8d
Show file tree
Hide file tree
Showing 19 changed files with 294 additions and 182 deletions.
81 changes: 48 additions & 33 deletions sycl/source/detail/event_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,19 @@ void event_impl::initContextIfNeeded() {

event_impl::~event_impl() {
try {
if (MEvent)
getPlugin()->call(urEventRelease, MEvent);
auto Handle = this->getHandle();
if (Handle)
getPlugin()->call(urEventRelease, Handle);
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~event_impl", e);
}
}

void event_impl::waitInternal(bool *Success) {
if (!MIsHostEvent && MEvent) {
auto Handle = this->getHandle();
if (!MIsHostEvent && Handle) {
// Wait for the native event
ur_result_t Err = getPlugin()->call_nocheck(urEventWait, 1, &MEvent);
ur_result_t Err = getPlugin()->call_nocheck(urEventWait, 1, &Handle);
// TODO drop the UR_RESULT_ERROR_UKNOWN from here (this was waiting for
// https://github.com/oneapi-src/unified-runtime/issues/1459 which is now
// closed).
Expand Down Expand Up @@ -89,7 +91,7 @@ void event_impl::waitInternal(bool *Success) {
}

void event_impl::setComplete() {
if (MIsHostEvent || !MEvent) {
if (MIsHostEvent || !this->getHandle()) {
{
std::unique_lock<std::mutex> lock(MMutex);
#ifndef NDEBUG
Expand All @@ -116,8 +118,11 @@ static uint64_t inline getTimestamp() {
.count();
}

const ur_event_handle_t &event_impl::getHandleRef() const { return MEvent; }
ur_event_handle_t &event_impl::getHandleRef() { return MEvent; }
ur_event_handle_t event_impl::getHandle() const { return MEvent.load(); }

void event_impl::setHandle(const ur_event_handle_t &UREvent) {
MEvent.store(UREvent);
}

const ContextImplPtr &event_impl::getContextImpl() {
initContextIfNeeded();
Expand All @@ -141,7 +146,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext)
MIsFlushed(true), MState(HES_Complete) {

ur_context_handle_t TempContext;
getPlugin()->call(urEventGetInfo, MEvent, UR_EVENT_INFO_CONTEXT,
getPlugin()->call(urEventGetInfo, this->getHandle(), UR_EVENT_INFO_CONTEXT,
sizeof(ur_context_handle_t), &TempContext, nullptr);

if (MContext->getHandleRef() != TempContext) {
Expand Down Expand Up @@ -183,7 +188,7 @@ void *event_impl::instrumentationProlog(std::string &Name, int32_t StreamID,
// Create a string with the event address so it
// can be associated with other debug data
xpti::utils::StringHelper SH;
Name = SH.nameWithAddress<ur_event_handle_t>("event.wait", MEvent);
Name = SH.nameWithAddress<ur_event_handle_t>("event.wait", this->getHandle());

// We can emit the wait associated with the graph if the
// event does not have a command object or associated with
Expand Down Expand Up @@ -249,9 +254,10 @@ void event_impl::wait(std::shared_ptr<sycl::detail::event_impl> Self,
TelemetryEvent = instrumentationProlog(Name, StreamID, IId);
#endif

if (MEvent)
// presence of MEvent means the command has been enqueued, so no need to
// go via the slow path event waiting in the scheduler
auto EventHandle = getHandle();
if (EventHandle)
// presence of the native handle means the command has been enqueued, so no
// need to go via the slow path event waiting in the scheduler
waitInternal(Success);
else if (MCommand)
detail::Scheduler::getInstance().waitForEvent(Self, Success);
Expand Down Expand Up @@ -294,7 +300,7 @@ event_impl::get_profiling_info<info::event_profiling::command_submit>() {
// For profiling tag events we rely on the submission time reported as
// the start time has undefined behavior.
return get_event_profiling_info<info::event_profiling::command_submit>(
this->getHandleRef(), this->getPlugin());
this->getHandle(), this->getPlugin());
}

// The delay between the submission and the actual start of a CommandBuffer
Expand All @@ -311,10 +317,11 @@ event_impl::get_profiling_info<info::event_profiling::command_submit>() {
// made by forcing the re-sync of submit time to start time is less than
// 0.5ms. These timing values were obtained empirically using an integrated
// Intel GPU).
if (MEventFromSubmittedExecCommandBuffer && !MIsHostEvent && MEvent) {
auto Handle = this->getHandle();
if (MEventFromSubmittedExecCommandBuffer && !MIsHostEvent && Handle) {
uint64_t StartTime =
get_event_profiling_info<info::event_profiling::command_start>(
this->getHandleRef(), this->getPlugin());
Handle, this->getPlugin());
if (StartTime < MSubmitTime)
MSubmitTime = StartTime;
}
Expand All @@ -326,16 +333,17 @@ uint64_t
event_impl::get_profiling_info<info::event_profiling::command_start>() {
checkProfilingPreconditions();
if (!MIsHostEvent) {
if (MEvent) {
auto Handle = getHandle();
if (Handle) {
auto StartTime =
get_event_profiling_info<info::event_profiling::command_start>(
this->getHandleRef(), this->getPlugin());
Handle, this->getPlugin());
if (!MFallbackProfiling) {
return StartTime;
} else {
auto DeviceBaseTime =
get_event_profiling_info<info::event_profiling::command_submit>(
this->getHandleRef(), this->getPlugin());
Handle, this->getPlugin());
return MHostBaseTime - DeviceBaseTime + StartTime;
}
}
Expand All @@ -353,16 +361,17 @@ template <>
uint64_t event_impl::get_profiling_info<info::event_profiling::command_end>() {
checkProfilingPreconditions();
if (!MIsHostEvent) {
if (MEvent) {
auto Handle = this->getHandle();
if (Handle) {
auto EndTime =
get_event_profiling_info<info::event_profiling::command_end>(
this->getHandleRef(), this->getPlugin());
Handle, this->getPlugin());
if (!MFallbackProfiling) {
return EndTime;
} else {
auto DeviceBaseTime =
get_event_profiling_info<info::event_profiling::command_submit>(
this->getHandleRef(), this->getPlugin());
Handle, this->getPlugin());
return MHostBaseTime - DeviceBaseTime + EndTime;
}
}
Expand All @@ -377,8 +386,9 @@ uint64_t event_impl::get_profiling_info<info::event_profiling::command_end>() {
}

template <> uint32_t event_impl::get_info<info::event::reference_count>() {
if (!MIsHostEvent && MEvent) {
return get_event_info<info::event::reference_count>(this->getHandleRef(),
auto Handle = this->getHandle();
if (!MIsHostEvent && Handle) {
return get_event_info<info::event::reference_count>(Handle,
this->getPlugin());
}
return 0;
Expand All @@ -392,9 +402,10 @@ event_impl::get_info<info::event::command_execution_status>() {

if (!MIsHostEvent) {
// Command is enqueued and UrEvent is ready
if (MEvent)
auto Handle = this->getHandle();
if (Handle)
return get_event_info<info::event::command_execution_status>(
this->getHandleRef(), this->getPlugin());
Handle, this->getPlugin());
// Command is blocked and not enqueued, UrEvent is not assigned yet
else if (MCommand)
return sycl::info::event_command_status::submitted;
Expand Down Expand Up @@ -471,17 +482,20 @@ ur_native_handle_t event_impl::getNative() {
initContextIfNeeded();

auto Plugin = getPlugin();
if (MIsDefaultConstructed && !MEvent) {
auto Handle = getHandle();
if (MIsDefaultConstructed && !Handle) {
auto TempContext = MContext.get()->getHandleRef();
ur_event_native_properties_t NativeProperties{};
ur_event_handle_t UREvent = nullptr;
Plugin->call(urEventCreateWithNativeHandle, 0, TempContext,
&NativeProperties, &MEvent);
&NativeProperties, &UREvent);
this->setHandle(UREvent);
}
if (MContext->getBackend() == backend::opencl)
Plugin->call(urEventRetain, getHandleRef());
ur_native_handle_t Handle;
Plugin->call(urEventGetNativeHandle, getHandleRef(), &Handle);
return Handle;
Plugin->call(urEventRetain, Handle);
ur_native_handle_t OutHandle;
Plugin->call(urEventGetNativeHandle, Handle, &OutHandle);
return OutHandle;
}

std::vector<EventImplPtr> event_impl::getWaitList() {
Expand All @@ -505,7 +519,8 @@ std::vector<EventImplPtr> event_impl::getWaitList() {
void event_impl::flushIfNeeded(const QueueImplPtr &UserQueue) {
// Some events might not have a native handle underneath even at this point,
// e.g. those produced by memset with 0 size (no UR call is made).
if (MIsFlushed || !MEvent)
auto Handle = this->getHandle();
if (MIsFlushed || !Handle)
return;

QueueImplPtr Queue = MQueue.lock();
Expand All @@ -520,7 +535,7 @@ void event_impl::flushIfNeeded(const QueueImplPtr &UserQueue) {

// Check if the task for this event has already been submitted.
ur_event_status_t Status = UR_EVENT_STATUS_QUEUED;
getPlugin()->call(urEventGetInfo, MEvent,
getPlugin()->call(urEventGetInfo, Handle,
UR_EVENT_INFO_COMMAND_EXECUTION_STATUS,
sizeof(ur_event_status_t), &Status, nullptr);
if (Status == UR_EVENT_STATUS_QUEUED) {
Expand Down
19 changes: 7 additions & 12 deletions sycl/source/detail/event_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,11 @@ class event_impl {
/// Marks this event as completed.
void setComplete();

/// Returns raw interoperability event handle. Returned reference will be
/// invalid if event_impl was destroyed.
///
/// \return a reference to an instance of plug-in event handle.
ur_event_handle_t &getHandleRef();
/// Returns raw interoperability event handle. Returned reference will be
/// invalid if event_impl was destroyed.
///
/// \return a const reference to an instance of plug-in event handle.
const ur_event_handle_t &getHandleRef() const;
/// Returns raw interoperability event handle.
ur_event_handle_t getHandle() const;

/// Set event handle for this event object.
void setHandle(const ur_event_handle_t &UREvent);

/// Returns context that is associated with this event.
///
Expand Down Expand Up @@ -240,7 +235,7 @@ class event_impl {
/// have native handle.
///
/// @return true if no associated command and no event handle.
bool isNOP() { return !MCommand && !getHandleRef(); }
bool isNOP() { return !MCommand && !getHandle(); }

/// Calling this function queries the current device timestamp and sets it as
/// submission time for the command associated with this event.
Expand Down Expand Up @@ -344,7 +339,7 @@ class event_impl {
int32_t StreamID, uint64_t IId) const;
void checkProfilingPreconditions() const;

ur_event_handle_t MEvent = nullptr;
std::atomic<ur_event_handle_t> MEvent = nullptr;
// Stores submission time of command associated with event
uint64_t MSubmitTime = 0;
uint64_t MHostBaseTime = 0;
Expand Down
5 changes: 3 additions & 2 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
}

NewEvent = CreateNewEvent();
ur_event_handle_t *OutEvent = &NewEvent->getHandleRef();
ur_event_handle_t UREvent = nullptr;
// Merge requirements from the nodes into requirements (if any) from the
// handler.
CGData.MRequirements.insert(CGData.MRequirements.end(),
Expand All @@ -927,7 +927,8 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
}
ur_result_t Res = Queue->getPlugin()->call_nocheck(
urCommandBufferEnqueueExp, CommandBuffer, Queue->getHandleRef(), 0,
nullptr, OutEvent);
nullptr, &UREvent);
NewEvent->setHandle(UREvent);
if (Res == UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES) {
throw sycl::exception(
make_error_code(errc::invalid),
Expand Down
9 changes: 4 additions & 5 deletions sycl/source/detail/memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,9 @@ static void waitForEvents(const std::vector<EventImplPtr> &Events) {
if (!Events.empty()) {
const PluginPtr &Plugin = Events[0]->getPlugin();
std::vector<ur_event_handle_t> UrEvents(Events.size());
std::transform(Events.begin(), Events.end(), UrEvents.begin(),
[](const EventImplPtr &EventImpl) {
return EventImpl->getHandleRef();
});
std::transform(
Events.begin(), Events.end(), UrEvents.begin(),
[](const EventImplPtr &EventImpl) { return EventImpl->getHandle(); });
if (!UrEvents.empty() && UrEvents[0]) {
Plugin->call(urEventWait, UrEvents.size(), &UrEvents[0]);
}
Expand Down Expand Up @@ -313,7 +312,7 @@ void *MemoryManager::allocateInteropMemObject(
// If memory object is created with interop c'tor return cl_mem as is.
assert(TargetContext == InteropContext && "Expected matching contexts");

OutEventToWait = InteropEvent->getHandleRef();
OutEventToWait = InteropEvent->getHandle();
// Retain the event since it will be released during alloca command
// destruction
if (nullptr != OutEventToWait) {
Expand Down
15 changes: 9 additions & 6 deletions sycl/source/detail/queue_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ getUrEvents(const std::vector<sycl::event> &DepEvents) {
std::vector<ur_event_handle_t> RetUrEvents;
for (const sycl::event &Event : DepEvents) {
const EventImplPtr &EventImpl = detail::getSyclObjImpl(Event);
if (EventImpl->getHandleRef() != nullptr)
RetUrEvents.push_back(EventImpl->getHandleRef());
auto Handle = EventImpl->getHandle();
if (Handle != nullptr)
RetUrEvents.push_back(Handle);
}
return RetUrEvents;
}
Expand Down Expand Up @@ -307,7 +308,7 @@ void queue_impl::addEvent(const event &Event) {
}
// As long as the queue supports urQueueFinish we only need to store events
// for unenqueued commands and host tasks.
else if (MEmulateOOO || EImpl->getHandleRef() == nullptr) {
else if (MEmulateOOO || EImpl->getHandle() == nullptr) {
std::weak_ptr<event_impl> EventWeakPtr{EImpl};
std::lock_guard<std::mutex> Lock{MMutex};
MEventsWeak.push_back(std::move(EventWeakPtr));
Expand Down Expand Up @@ -447,8 +448,10 @@ event queue_impl::submitMemOpHelper(const std::shared_ptr<queue_impl> &Self,
auto EventImpl = detail::getSyclObjImpl(ResEvent);
{
NestedCallsTracker tracker;
MemOpFunc(MemOpArgs..., getUrEvents(ExpandedDepEvents),
&EventImpl->getHandleRef(), EventImpl);
ur_event_handle_t UREvent = nullptr;
MemOpFunc(MemOpArgs..., getUrEvents(ExpandedDepEvents), &UREvent,
EventImpl);
EventImpl->setHandle(UREvent);
}

if (isInOrder()) {
Expand Down Expand Up @@ -603,7 +606,7 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
EventImplWeakPtrIt->lock()) {
// A nullptr UR event indicates that urQueueFinish will not cover it,
// either because it's a host task event or an unenqueued one.
if (!SupportsPiFinish || nullptr == EventImplSharedPtr->getHandleRef()) {
if (!SupportsPiFinish || nullptr == EventImplSharedPtr->getHandle()) {
EventImplSharedPtr->wait(EventImplSharedPtr);
}
}
Expand Down
5 changes: 3 additions & 2 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,9 +737,10 @@ class queue_impl {
template <typename HandlerType = handler>
EventImplPtr insertHelperBarrier(const HandlerType &Handler) {
auto ResEvent = std::make_shared<detail::event_impl>(Handler.MQueue);
ur_event_handle_t UREvent = nullptr;
getPlugin()->call(urEnqueueEventsWaitWithBarrier,
Handler.MQueue->getHandleRef(), 0, nullptr,
&ResEvent->getHandleRef());
Handler.MQueue->getHandleRef(), 0, nullptr, &UREvent);
ResEvent->setHandle(UREvent);
return ResEvent;
}

Expand Down
6 changes: 4 additions & 2 deletions sycl/source/detail/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,10 @@ addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
auto EventImpl = std::make_shared<detail::event_impl>(Queue);
EventImpl->setContextImpl(detail::getSyclObjImpl(Queue->get_context()));
EventImpl->setStateIncomplete();
MemoryManager::fill_usm(Counter.get(), Queue, sizeof(int), {0}, {},
&EventImpl->getHandleRef(), EventImpl);
ur_event_handle_t UREvent = nullptr;
MemoryManager::fill_usm(Counter.get(), Queue, sizeof(int), {0}, {}, &UREvent,
EventImpl);
EventImpl->setHandle(UREvent);
CGH.depends_on(createSyclObjFromImpl<event>(EventImpl));
}

Expand Down
Loading

0 comments on commit a689b8d

Please sign in to comment.