Skip to content

Commit

Permalink
Merge pull request #984 from al42and/ext_oneapi_queue_priority-hip
Browse files Browse the repository at this point in the history
[HIP] Implement ext_oneapi_queue_priority
  • Loading branch information
kbenzie authored Jan 5, 2024
2 parents c311fe8 + b65f9d7 commit 46a886d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
24 changes: 19 additions & 5 deletions source/adapters/hip/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ hipStream_t ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
// The second check is done after mutex is locked so other threads can not
// change NumComputeStreams after that
if (NumComputeStreams < ComputeStreams.size()) {
UR_CHECK_ERROR(hipStreamCreateWithFlags(
&ComputeStreams[NumComputeStreams++], Flags));
UR_CHECK_ERROR(hipStreamCreateWithPriority(
&ComputeStreams[NumComputeStreams++], Flags, Priority));
}
}
Token = ComputeStreamIdx++;
Expand Down Expand Up @@ -97,8 +97,8 @@ hipStream_t ur_queue_handle_t_::getNextTransferStream() {
// The second check is done after mutex is locked so other threads can not
// change NumTransferStreams after that
if (NumTransferStreams < TransferStreams.size()) {
UR_CHECK_ERROR(hipStreamCreateWithFlags(
&TransferStreams[NumTransferStreams++], Flags));
UR_CHECK_ERROR(hipStreamCreateWithPriority(
&TransferStreams[NumTransferStreams++], Flags, Priority));
}
}
uint32_t Stream_i = TransferStreamIdx++ % TransferStreams.size();
Expand All @@ -118,6 +118,19 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
std::unique_ptr<ur_queue_handle_t_> QueueImpl{nullptr};

unsigned int Flags = 0;
ur_queue_flags_t URFlags = 0;
int Priority = 0; // Not guaranteed, but, in ROCm 5.7, 0 is the default

if (pProps && pProps->stype == UR_STRUCTURE_TYPE_QUEUE_PROPERTIES) {
URFlags = pProps->flags;
if (URFlags & UR_QUEUE_FLAG_PRIORITY_HIGH) {
ScopedContext Active(hContext->getDevice());
UR_CHECK_ERROR(hipDeviceGetStreamPriorityRange(nullptr, &Priority));
} else if (URFlags & UR_QUEUE_FLAG_PRIORITY_LOW) {
ScopedContext Active(hContext->getDevice());
UR_CHECK_ERROR(hipDeviceGetStreamPriorityRange(&Priority, nullptr));
}
}

const bool IsOutOfOrder =
pProps ? pProps->flags & UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE
Expand All @@ -130,7 +143,7 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,

QueueImpl = std::unique_ptr<ur_queue_handle_t_>(new ur_queue_handle_t_{
std::move(ComputeHipStreams), std::move(TransferHipStreams), hContext,
hDevice, Flags, pProps ? pProps->flags : 0});
hDevice, Flags, pProps ? pProps->flags : 0, Priority});

*phQueue = QueueImpl.release();

Expand Down Expand Up @@ -293,6 +306,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
hDevice,
HIPFlags,
Flags,
/*priority*/ 0,
/*backend_owns*/ pProperties->isNativeHandleOwned};
(*phQueue)->NumComputeStreams = 1;

Expand Down
5 changes: 3 additions & 2 deletions source/adapters/hip/queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct ur_queue_handle_t_ {
unsigned int LastSyncTransferStreams;
unsigned int Flags;
ur_queue_flags_t URFlags;
int Priority;
// When ComputeStreamSyncMutex and ComputeStreamMutex both need to be
// locked at the same time, ComputeStreamSyncMutex should be locked first
// to avoid deadlocks
Expand All @@ -56,7 +57,7 @@ struct ur_queue_handle_t_ {
ur_queue_handle_t_(std::vector<native_type> &&ComputeStreams,
std::vector<native_type> &&TransferStreams,
ur_context_handle_t Context, ur_device_handle_t Device,
unsigned int Flags, ur_queue_flags_t URFlags,
unsigned int Flags, ur_queue_flags_t URFlags, int Priority,
bool BackendOwns = true)
: ComputeStreams{std::move(ComputeStreams)}, TransferStreams{std::move(
TransferStreams)},
Expand All @@ -66,7 +67,7 @@ struct ur_queue_handle_t_ {
Device{Device}, RefCount{1}, EventCount{0}, ComputeStreamIdx{0},
TransferStreamIdx{0}, NumComputeStreams{0}, NumTransferStreams{0},
LastSyncComputeStreams{0}, LastSyncTransferStreams{0}, Flags(Flags),
URFlags(URFlags), HasOwnership{BackendOwns} {
URFlags(URFlags), Priority(Priority), HasOwnership{BackendOwns} {
urContextRetain(Context);
urDeviceRetain(Device);
}
Expand Down

0 comments on commit 46a886d

Please sign in to comment.