Skip to content

Commit

Permalink
Merge pull request #1830 from JackAKirk/hip-set-device
Browse files Browse the repository at this point in the history
[hip] Remove deprecated hip APIs, simplify urContext
  • Loading branch information
aarongreig authored Sep 23, 2024
2 parents 9ca3ec7 + be38e56 commit f5c907a
Show file tree
Hide file tree
Showing 21 changed files with 99 additions and 141 deletions.
2 changes: 1 addition & 1 deletion source/adapters/hip/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_event_handle_t *phEvent) {
try {
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_guard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down
23 changes: 11 additions & 12 deletions source/adapters/hip/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ ur_context_handle_t_::getOwningURPool(umf_memory_pool_t *UMFPool) {
return nullptr;
}

/// Create a UR HIP context.
///
/// By default creates a scoped context and keeps the last active HIP context
/// on top of the HIP context stack.
/// Create a UR context.
///
UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
Expand All @@ -44,7 +41,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(

std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
try {
// Create a scoped context.
// Create a context.
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
new ur_context_handle_t_{phDevices, DeviceCount});
*phContext = ContextPtr.release();
Expand Down Expand Up @@ -111,13 +108,15 @@ urContextRetain(ur_context_handle_t hContext) {
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
// FIXME: this entry point has been deprecated in the SYCL RT and should be
// changed to unsupported once the deprecation period has elapsed
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
hContext->getDevices()[0]->getNativeContext());
return UR_RESULT_SUCCESS;
// urContextGetNativeHandle is intentionally not implemented in the HIP
// backend: hipCtx_t is not natively supported by AMD devices and, more
// importantly, does not map to ur_context_handle_t in any way.
//
// Both parameters are ignored (hence [[maybe_unused]]); the call always
// returns UR_RESULT_ERROR_UNSUPPORTED_FEATURE and never writes through
// phNativeContext.
UR_APIEXPORT ur_result_t UR_APICALL
urContextGetNativeHandle([[maybe_unused]] ur_context_handle_t hContext,
                         [[maybe_unused]] ur_native_handle_t *phNativeContext) {
  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
Expand Down
42 changes: 10 additions & 32 deletions source/adapters/hip/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
/// See proposal for details.
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
///
///
/// <b> Destructor callback </b>
///
/// Required to implement CP023, SYCL Extended Context Destruction,
/// the UR Context can store a number of callback functions that will be
/// called upon destruction of the UR Context.
/// See proposal for details.
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
///
/// <b> Memory Management for Devices in a Context </b>
///
/// A \c ur_mem_handle_t is associated with a \c ur_context_handle_t_, which
Expand All @@ -76,8 +85,6 @@ struct ur_context_handle_t_ {
void operator()() { Function(UserData); }
};

using native_type = hipCtx_t;

std::vector<ur_device_handle_t> Devices;

std::atomic_uint32_t RefCount;
Expand All @@ -89,11 +96,7 @@ struct ur_context_handle_t_ {
}
};

~ur_context_handle_t_() {
for (auto &Dev : Devices) {
urDeviceRelease(Dev);
}
}
~ur_context_handle_t_() {}

void invokeExtendedDeleters() {
std::lock_guard<std::mutex> Guard(Mutex);
Expand Down Expand Up @@ -136,28 +139,3 @@ struct ur_context_handle_t_ {
std::vector<deleter_data> ExtendedDeleters;
std::set<ur_usm_pool_handle_t> PoolHandles;
};

namespace {
/// ScopedContext is used throughout the UR HIP adapter implementation to
/// activate the native context of a device on the current thread.
///
/// The ScopedContext does not reinstate the previous context: it has no
/// destructor, because all operations in the HIP adapter that require an
/// active context set the active context themselves and do not rely on
/// context reinstatement.
class ScopedContext {
public:
/// Makes the native context of \p hDevice current if it is not already.
///
/// Throws UR_RESULT_ERROR_INVALID_DEVICE when \p hDevice is null. The
/// results of the HIP calls are passed through UR_CHECK_ERROR.
ScopedContext(ur_device_handle_t hDevice) {
hipCtx_t Original{};

if (!hDevice) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}

hipCtx_t Desired = hDevice->getNativeContext();
// Query the context that is currently active so that the switch below is
// skipped when the desired context is already current.
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
if (Original != Desired) {
// Sets the desired context as the active one for the thread
UR_CHECK_ERROR(hipCtxSetCurrent(Desired));
}
}
};
} // namespace
2 changes: 1 addition & 1 deletion source/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
return UR_RESULT_SUCCESS;

ur_event_handle_t_::native_type Event;
ScopedContext Active(hDevice);
ScopedDevice Active(hDevice);

if (pDeviceTimestamp) {
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));
Expand Down
33 changes: 22 additions & 11 deletions source/adapters/hip/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ struct ur_device_handle_t_ {
native_type HIPDevice;
std::atomic_uint32_t RefCount;
ur_platform_handle_t Platform;
hipCtx_t HIPContext;
hipEvent_t EvBase; // HIP event used as base counter
uint32_t DeviceIndex;

Expand All @@ -37,11 +36,10 @@ struct ur_device_handle_t_ {
int ConcurrentManagedAccess{0};

public:
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
hipEvent_t EvBase, ur_platform_handle_t Platform,
uint32_t DeviceIndex)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
HIPContext(Context), EvBase(EvBase), DeviceIndex(DeviceIndex) {
ur_device_handle_t_(native_type HipDevice, hipEvent_t EvBase,
ur_platform_handle_t Platform, uint32_t DeviceIndex)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform), EvBase(EvBase),
DeviceIndex(DeviceIndex) {

UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
Expand All @@ -61,9 +59,7 @@ struct ur_device_handle_t_ {
HIPDevice));
}

~ur_device_handle_t_() noexcept(false) {
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
}
~ur_device_handle_t_() noexcept(false) {}

native_type get() const noexcept { return HIPDevice; };

Expand All @@ -73,8 +69,6 @@ struct ur_device_handle_t_ {

uint64_t getElapsedTime(hipEvent_t) const;

hipCtx_t getNativeContext() const noexcept { return HIPContext; };

// Returns the index of the device relative to the other devices in the same
// platform
uint32_t getIndex() const noexcept { return DeviceIndex; };
Expand All @@ -97,3 +91,20 @@ struct ur_device_handle_t_ {
};

int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);

namespace {
/// ScopedDevice is used throughout the UR HIP adapter implementation to
/// activate the native device on the current thread.
///
/// The ScopedDevice does not reinstate the previous device: it has no
/// destructor, because all operations in the HIP adapter that require an
/// active device set the active device themselves and do not rely on device
/// reinstatement.
class ScopedDevice {
public:
/// Makes \p hDevice the active HIP device by passing its index to
/// hipSetDevice.
///
/// Throws UR_RESULT_ERROR_INVALID_DEVICE when \p hDevice is null. The result
/// of hipSetDevice is passed through UR_CHECK_ERROR.
ScopedDevice(ur_device_handle_t hDevice) {
if (!hDevice) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}
// NOTE(review): this uses getIndex() as the HIP device ordinal — confirm the
// index always matches the ordinal hipSetDevice expects.
UR_CHECK_ERROR(hipSetDevice(hDevice->getIndex()));
}
};
} // namespace
44 changes: 22 additions & 22 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
auto Result = forLatestEvents(
EventWaitList, NumEventsInWaitList,
[Stream, Queue](ur_event_handle_t Event) -> ur_result_t {
ScopedContext Active(Queue->getDevice());
ScopedDevice Active(Queue->getDevice());
if (Event->isCompleted() || Event->getStream() == Stream) {
return UR_RESULT_SUCCESS;
} else {
Expand Down Expand Up @@ -164,7 +164,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
hBuffer->setLastQueueWritingToMemObj(hQueue);

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));
Expand Down Expand Up @@ -220,7 +220,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
}

auto Device = hQueue->getDevice();
ScopedContext Active(Device);
ScopedDevice Active(Device);
hipStream_t HIPStream = hQueue->getNextTransferStream();

// Use the default stream if copying from another device
Expand Down Expand Up @@ -290,7 +290,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
pGlobalWorkSize, pLocalWorkSize, hKernel,
HIPFunc, ThreadsPerBlock, BlocksPerGrid));

ScopedContext Active(Dev);
ScopedDevice Active(Dev);

uint32_t StreamToken;
ur_stream_guard Guard;
Expand Down Expand Up @@ -378,7 +378,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST)

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_guard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -533,7 +533,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
}

auto Device = hQueue->getDevice();
ScopedContext Active(Device);
ScopedDevice Active(Device);
hipStream_t HIPStream = hQueue->getNextTransferStream();

UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
Expand Down Expand Up @@ -582,7 +582,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
hBuffer->setLastQueueWritingToMemObj(hQueue);

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));
Expand Down Expand Up @@ -629,7 +629,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
ur_result_t Result = UR_RESULT_SUCCESS;
auto Stream = hQueue->getNextTransferStream();

Expand Down Expand Up @@ -680,7 +680,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -794,7 +794,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
hBuffer->setLastQueueWritingToMemObj(hQueue);

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());

auto Stream = hQueue->getNextTransferStream();
if (phEventWaitList) {
Expand Down Expand Up @@ -941,7 +941,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
}

auto Device = hQueue->getDevice();
ScopedContext Active(Device);
ScopedDevice Active(Device);
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Expand Down Expand Up @@ -1001,7 +1001,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
UR_ASSERT(hImage->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Expand Down Expand Up @@ -1066,7 +1066,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
if (phEventWaitList) {
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
Expand Down Expand Up @@ -1161,7 +1161,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
hQueue, hBuffer, blockingMap, offset, size, MapPtr,
numEventsInWaitList, phEventWaitList, phEvent));
} else {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());

if (IsPinned) {
UR_CHECK_ERROR(urEnqueueEventsWait(hQueue, numEventsInWaitList,
Expand Down Expand Up @@ -1211,7 +1211,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
hQueue, hMem, true, Map->getMapOffset(), Map->getMapSize(),
pMappedPtr, numEventsInWaitList, phEventWaitList, phEvent));
} else {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());

if (IsPinned) {
UR_CHECK_ERROR(urEnqueueEventsWait(hQueue, numEventsInWaitList,
Expand Down Expand Up @@ -1241,7 +1241,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_guard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -1299,7 +1299,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -1348,7 +1348,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -1425,7 +1425,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
#endif

try {
ScopedContext Active(Device);
ScopedDevice Active(Device);
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

if (phEvent) {
Expand Down Expand Up @@ -1561,7 +1561,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -1762,7 +1762,7 @@ setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
size_t MaxWorkGroupSize = 0;
ur_result_t Result = UR_RESULT_SUCCESS;
try {
ScopedContext Active(Device);
ScopedDevice Active(Device);
{
size_t MaxThreadsPerBlock[3] = {
static_cast<size_t>(Device->getMaxBlockDimX()),
Expand Down Expand Up @@ -1906,7 +1906,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueTimestampRecordingExp(
ur_result_t Result = UR_RESULT_SUCCESS;
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());

uint32_t StreamToken;
ur_stream_guard Guard;
Expand Down
Loading

0 comments on commit f5c907a

Please sign in to comment.