Revert add prefetch for USM hip allocations a6b8fa66b537753415d24076f1025c040110c332
hdelan committed Nov 21, 2023
1 parent 04799e7 commit 841a287
Showing 4 changed files with 2 additions and 79 deletions.
53 changes: 0 additions & 53 deletions source/adapters/hip/context.hpp
@@ -10,7 +10,6 @@
 #pragma once
 
 #include <set>
-#include <unordered_map>
 
 #include "common.hpp"
 #include "device.hpp"
@@ -106,61 +105,9 @@ struct ur_context_handle_t_ {
 
   ur_usm_pool_handle_t getOwningURPool(umf_memory_pool_t *UMFPool);
 
-  /// We need to keep track of USM mappings in AMD HIP, as certain extra
-  /// synchronization *is* actually required for correctness.
-  /// During kernel enqueue we must dispatch a prefetch for each kernel argument
-  /// that points to a USM mapping to ensure the mapping is correctly
-  /// populated on the device (https://github.com/intel/llvm/issues/7252). Thus,
-  /// we keep track of mappings in the context, and then check against them just
-  /// before the kernel is launched. The stream against which the kernel is
-  /// launched is not known until enqueue time, but the USM mappings can happen
-  /// at any time. Thus, they are tracked on the context used for the urUSM*
-  /// mapping.
-  ///
-  /// The three utility functions are simple wrappers around a mapping from a
-  /// pointer to a size.
-  void addUSMMapping(void *Ptr, size_t Size) {
-    std::lock_guard<std::mutex> Guard(Mutex);
-    assert(USMMappings.find(Ptr) == USMMappings.end() &&
-           "mapping already exists");
-    USMMappings[Ptr] = Size;
-  }
-
-  void removeUSMMapping(const void *Ptr) {
-    std::lock_guard<std::mutex> Guard(Mutex);
-    auto It = USMMappings.find(Ptr);
-    if (It != USMMappings.end())
-      USMMappings.erase(It);
-  }
-
-  std::pair<const void *, size_t> getUSMMapping(const void *Ptr) {
-    std::lock_guard<std::mutex> Guard(Mutex);
-    auto It = USMMappings.find(Ptr);
-    // The simple case is the fast case...
-    if (It != USMMappings.end())
-      return *It;
-
-    // ... but in the failure case we have to fall back to a full scan to search
-    // for "offset" pointers in case the user passes in the middle of an
-    // allocation. We have to do some not-so-ordained-by-the-standard ordered
-    // comparisons of pointers here, but it'll work on all platforms we support.
-    uintptr_t PtrVal = (uintptr_t)Ptr;
-    for (std::pair<const void *, size_t> Pair : USMMappings) {
-      uintptr_t BaseAddr = (uintptr_t)Pair.first;
-      uintptr_t EndAddr = BaseAddr + Pair.second;
-      if (PtrVal > BaseAddr && PtrVal < EndAddr) {
-        // If we've found something now, the offset *must* be nonzero
-        assert(Pair.second);
-        return Pair;
-      }
-    }
-    return {nullptr, 0};
-  }
-
 private:
   std::mutex Mutex;
   std::vector<deleter_data> ExtendedDeleters;
-  std::unordered_map<const void *, size_t> USMMappings;
   std::set<ur_usm_pool_handle_t> PoolHandles;
 };
 
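Note: the members deleted above amounted to a registry mapping USM base pointers to allocation sizes, whose lookup also resolved pointers into the middle of an allocation. The following standalone sketch (illustrative only; addMapping, findMapping, and the driver are not adapter code) shows that lookup behaviour:

// Minimal standalone sketch of the lookup the removed getUSMMapping performed:
// exact hits are O(1); interior ("offset") pointers fall back to a linear scan
// over all recorded allocations.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <utility>

static std::unordered_map<const void *, size_t> Mappings;

void addMapping(const void *Ptr, size_t Size) { Mappings[Ptr] = Size; }

std::pair<const void *, size_t> findMapping(const void *Ptr) {
  auto It = Mappings.find(Ptr);
  if (It != Mappings.end())
    return *It; // fast path: pointer is the base of an allocation
  auto PtrVal = reinterpret_cast<uintptr_t>(Ptr);
  for (const auto &Pair : Mappings) {
    auto Base = reinterpret_cast<uintptr_t>(Pair.first);
    if (PtrVal > Base && PtrVal < Base + Pair.second)
      return Pair; // slow path: pointer lands inside an allocation
  }
  return {nullptr, 0};
}

int main() {
  char Buffer[64];
  addMapping(Buffer, sizeof(Buffer));
  auto [Base, Size] = findMapping(Buffer + 10); // interior pointer
  assert(Base == Buffer && Size == sizeof(Buffer));
  std::printf("found base %p, size %zu\n", Base, Size);
}

The linear fallback is what made interior pointers work, at the cost of an O(n) scan whenever the fast path misses.
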
11 changes: 1 addition & 10 deletions source/adapters/hip/enqueue.cpp
@@ -258,22 +258,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   try {
     ur_device_handle_t Dev = hQueue->getDevice();
     ScopedContext Active(Dev);
-    ur_context_handle_t Ctx = hQueue->getContext();
 
     uint32_t StreamToken;
     ur_stream_guard_ Guard;
     hipStream_t HIPStream = hQueue->getNextComputeStream(
         numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
     hipFunction_t HIPFunc = hKernel->get();
 
-    hipDevice_t HIPDev = Dev->get();
-    for (const void *P : hKernel->getPtrArgs()) {
-      auto [Addr, Size] = Ctx->getUSMMapping(P);
-      if (!Addr)
-        continue;
-      if (hipMemPrefetchAsync(Addr, Size, HIPDev, HIPStream) != hipSuccess)
-        return UR_RESULT_ERROR_INVALID_KERNEL_ARGS;
-    }
     Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
                                phEventWaitList);
 
@@ -315,7 +306,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
       int DeviceMaxLocalMem = 0;
       UR_CHECK_ERROR(hipDeviceGetAttribute(
           &DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
-          HIPDev));
+          Dev->get()));
 
       static const int EnvVal = std::atoi(LocalMemSzPtr);
       if (EnvVal <= 0 || EnvVal > DeviceMaxLocalMem) {
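Note: the loop deleted from urEnqueueKernelLaunch above issued one hipMemPrefetchAsync per pointer kernel argument on the launch stream. A rough standalone HIP equivalent of that pattern (illustrative only; the kernel, sizes, and minimal error handling are placeholders, and a device with managed-memory support is assumed) is:

// Standalone HIP sketch of prefetch-before-launch: migrate a managed (USM-style)
// allocation to the target device on the same stream the kernel will use.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void scale(float *Data, int N) {
  int I = blockIdx.x * blockDim.x + threadIdx.x;
  if (I < N)
    Data[I] *= 2.0f;
}

int main() {
  constexpr int N = 1 << 20;
  int Device = 0;
  (void)hipGetDevice(&Device);

  hipStream_t Stream;
  (void)hipStreamCreate(&Stream);

  float *Data = nullptr;
  (void)hipMallocManaged(reinterpret_cast<void **>(&Data), N * sizeof(float));
  for (int I = 0; I < N; ++I)
    Data[I] = 1.0f;

  // The removed adapter code issued this call once per pointer kernel argument.
  if (hipMemPrefetchAsync(Data, N * sizeof(float), Device, Stream) != hipSuccess)
    std::fprintf(stderr, "prefetch failed\n");

  hipLaunchKernelGGL(scale, dim3((N + 255) / 256), dim3(256), 0, Stream, Data, N);
  (void)hipStreamSynchronize(Stream);

  std::printf("Data[0] = %f\n", Data[0]);
  (void)hipFree(Data);
  (void)hipStreamDestroy(Stream);
}

Issuing the prefetch on the launch stream orders the migration before the kernel without adding an extra synchronization point.
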
2 changes: 1 addition & 1 deletion source/adapters/hip/kernel.cpp
@@ -259,7 +259,7 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
 UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
     ur_kernel_handle_t hKernel, uint32_t argIndex,
     const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) {
-  hKernel->setKernelPtrArg(argIndex, sizeof(pArgValue), pArgValue);
+  hKernel->setKernelArg(argIndex, sizeof(pArgValue), pArgValue);
   return UR_RESULT_SUCCESS;
 }

15 changes: 0 additions & 15 deletions source/adapters/hip/kernel.hpp
@@ -14,7 +14,6 @@
 #include <atomic>
 #include <cassert>
 #include <numeric>
-#include <set>
 
 #include "program.hpp"
 
@@ -58,7 +57,6 @@ struct ur_kernel_handle_t_ {
     args_size_t ParamSizes;
     args_index_t Indices;
     args_size_t OffsetPerIndex;
-    std::set<const void *> PtrArgs;
 
     std::uint32_t ImplicitOffsetArgs[3] = {0, 0, 0};
 
@@ -179,19 +177,6 @@ struct ur_kernel_handle_t_ {
     Args.addArg(Index, Size, Arg);
   }
 
-  /// We track all pointer arguments to be able to issue prefetches at enqueue
-  /// time
-  void setKernelPtrArg(int Index, size_t Size, const void *PtrArg) {
-    Args.PtrArgs.insert(*static_cast<void *const *>(PtrArg));
-    setKernelArg(Index, Size, PtrArg);
-  }
-
-  bool isPtrArg(const void *ptr) {
-    return Args.PtrArgs.find(ptr) != Args.PtrArgs.end();
-  }
-
-  std::set<const void *> &getPtrArgs() { return Args.PtrArgs; }
-
   void setKernelLocalArg(int Index, size_t Size) {
     Args.addLocalArg(Index, Size);
   }
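Note: pArgValue in urKernelSetArgPointer is the address of the USM pointer, not the pointer itself, which is why the deleted setKernelPtrArg dereferenced it via *static_cast<void *const *>(PtrArg) before recording it. A tiny illustrative snippet (not adapter code) of that double indirection:

// The UR entry point receives a pointer to the argument value; for a pointer
// argument, one dereference recovers the device-visible address.
#include <cassert>

int main() {
  int Payload = 42;
  void *UsmPtr = &Payload;        // what the kernel will actually dereference
  const void *ArgValue = &UsmPtr; // what urKernelSetArgPointer receives

  // Equivalent of Args.PtrArgs.insert(*static_cast<void *const *>(PtrArg));
  const void *Recorded = *static_cast<void *const *>(ArgValue);
  assert(Recorded == &Payload);
}
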
