Skip to content

Commit

Permalink
Merge pull request #1186 from hdelan/device-global-hip
Browse files Browse the repository at this point in the history
[HIP] Add support for global variable read write
  • Loading branch information
aarongreig authored Jan 12, 2024
2 parents 25e0b60 + 45d76b7 commit 79c28d0
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 6 deletions.
63 changes: 57 additions & 6 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1697,16 +1697,67 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
return Result;
}

namespace {

/// Direction of a device-global transfer, seen from the device variable:
/// Read copies out of the variable, Write copies into it.
enum class GlobalVariableCopy { Read, Write };

/// Shared implementation behind the device-global read and write entry points.
///
/// Looks up the symbol name registered for `name` in the program's GlobalIDMD
/// map, asks HIP for that symbol's device address and size, validates the
/// requested byte range and then forwards the transfer to urEnqueueUSMMemcpy.
///
/// \param hQueue    Queue handed through to urEnqueueUSMMemcpy.
/// \param hProgram  Program whose module and GlobalIDMD map are queried.
/// \param name      Key used to find the device symbol name in GlobalIDMD.
/// \param blocking  Handed through to urEnqueueUSMMemcpy.
/// \param count     Number of bytes to copy.
/// \param offset    Byte offset into the device global.
/// \param ptr       Source buffer for Write, destination buffer for Read.
/// \param numEventsInWaitList Handed through to urEnqueueUSMMemcpy.
/// \param phEventWaitList     Handed through to urEnqueueUSMMemcpy.
/// \param phEvent             Handed through to urEnqueueUSMMemcpy.
/// \param CopyType  Selects the copy direction.
/// \return UR_RESULT_ERROR_INVALID_VALUE when `name` has no GlobalIDMD entry
///         or when [offset, offset + count) does not fit inside the global;
///         any ur_result_t thrown inside the try block (presumably by
///         UR_CHECK_ERROR); otherwise the result of urEnqueueUSMMemcpy.
ur_result_t deviceGlobalCopyHelper(
    ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
    bool blocking, size_t count, size_t offset, void *ptr,
    uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
    ur_event_handle_t *phEvent, GlobalVariableCopy CopyType) {
  // Since HIP requires the global variable to be referenced by name, we use
  // metadata to find the correct name to access it by.
  auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(name);
  if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
    return UR_RESULT_ERROR_INVALID_VALUE;
  // Bind by reference: the map entry is not modified here, so no copy of the
  // symbol name is needed.
  const std::string &DeviceGlobalName = DeviceGlobalNameIt->second;

  try {
    hipDeviceptr_t DeviceGlobal = 0;
    size_t DeviceGlobalSize = 0;
    UR_CHECK_ERROR(hipModuleGetGlobal(&DeviceGlobal, &DeviceGlobalSize,
                                      hProgram->get(),
                                      DeviceGlobalName.c_str()));

    // Written so that it cannot wrap: `offset + count` may overflow size_t
    // and would then slip past a naive `offset + count > size` comparison.
    if (offset > DeviceGlobalSize || count > DeviceGlobalSize - offset)
      return UR_RESULT_ERROR_INVALID_VALUE;

    // Address of the first byte of the requested range inside the global.
    void *pDeviceRange = reinterpret_cast<uint8_t *>(DeviceGlobal) + offset;

    void *pSrc, *pDst;
    if (CopyType == GlobalVariableCopy::Write) {
      pSrc = ptr;
      pDst = pDeviceRange;
    } else {
      pSrc = pDeviceRange;
      pDst = ptr;
    }
    return urEnqueueUSMMemcpy(hQueue, blocking, pDst, pSrc, count,
                              numEventsInWaitList, phEventWaitList, phEvent);
  } catch (ur_result_t Err) {
    return Err;
  }
}
} // namespace

// Enqueues a copy of `count` bytes from `pSrc` into a device global variable,
// starting `offset` bytes into it. Every argument is forwarded to
// deviceGlobalCopyHelper with GlobalVariableCopy::Write; the const_cast exists
// only because the helper takes a single non-const `void *` for both copy
// directions and uses it purely as the copy source in the Write case.
// NOTE(review): this listing interleaves two versions of the function - an
// unnamed-parameter signature returning UR_RESULT_ERROR_UNSUPPORTED_FEATURE
// and a named-parameter signature calling the helper. Presumably the former
// is the removed side of the diff - confirm before relying on this text.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
ur_queue_handle_t, ur_program_handle_t, const char *, bool, size_t, size_t,
const void *, uint32_t, const ur_event_handle_t *, ur_event_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
bool blockingWrite, size_t count, size_t offset, const void *pSrc,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
return deviceGlobalCopyHelper(hQueue, hProgram, name, blockingWrite, count,
offset, const_cast<void *>(pSrc),
numEventsInWaitList, phEventWaitList, phEvent,
GlobalVariableCopy::Write);
}

// Enqueues a copy of `count` bytes out of a device global variable, starting
// `offset` bytes into it, to `pDst`. Every argument is forwarded to
// deviceGlobalCopyHelper with GlobalVariableCopy::Read.
// NOTE(review): this listing interleaves two versions of the function - an
// unnamed-parameter signature returning UR_RESULT_ERROR_UNSUPPORTED_FEATURE
// and a named-parameter signature calling the helper. Presumably the former
// is the removed side of the diff - confirm before relying on this text.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
ur_queue_handle_t, ur_program_handle_t, const char *, bool, size_t, size_t,
void *, uint32_t, const ur_event_handle_t *, ur_event_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
bool blockingRead, size_t count, size_t offset, void *pDst,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
return deviceGlobalCopyHelper(
hQueue, hProgram, name, blockingRead, count, offset, pDst,
numEventsInWaitList, phEventWaitList, phEvent, GlobalVariableCopy::Read);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe(
Expand Down
18 changes: 18 additions & 0 deletions source/adapters/hip/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,35 @@ void getCoMgrBuildLog(const amd_comgr_data_set_t BuildDataSet, char *BuildLog,
} // namespace
#endif

/// Splits a metadata element name at its last '@'.
///
/// \param metadataName Full metadata name to split.
/// \return {prefix, tag}: prefix is everything before the last '@' and tag is
///         the rest of the string, '@' included. When the name contains no
///         '@', prefix is the whole name and tag is empty.
std::pair<std::string, std::string>
splitMetadataName(const std::string &metadataName) {
  const std::string::size_type atPos = metadataName.rfind('@');
  if (atPos != std::string::npos) {
    // The tag keeps its leading '@'.
    return {metadataName.substr(0, atPos), metadataName.substr(atPos)};
  }
  return {metadataName, std::string{}};
}

ur_result_t
ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
size_t Length) {
for (size_t i = 0; i < Length; ++i) {
const ur_program_metadata_t MetadataElement = Metadata[i];
std::string MetadataElementName{MetadataElement.pName};

auto [Prefix, Tag] = splitMetadataName(MetadataElementName);

if (MetadataElementName ==
__SYCL_UR_PROGRAM_METADATA_TAG_NEED_FINALIZATION) {
assert(MetadataElement.type == UR_PROGRAM_METADATA_TYPE_UINT32);
IsRelocatable = MetadataElement.value.data32;
} else if (Tag == __SYCL_UR_PROGRAM_METADATA_GLOBAL_ID_MAPPING) {
const char *MetadataValPtr =
reinterpret_cast<const char *>(MetadataElement.value.pData) +
sizeof(std::uint64_t);
const char *MetadataValPtrEnd =
MetadataValPtr + MetadataElement.size - sizeof(std::uint64_t);
GlobalIDMD[Prefix] = std::string{MetadataValPtr, MetadataValPtrEnd};
}
}
return UR_RESULT_SUCCESS;
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/hip/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ struct ur_program_handle_t_ {
// Metadata
bool IsRelocatable = false;

std::unordered_map<std::string, std::string> GlobalIDMD;

constexpr static size_t MAX_LOG_SIZE = 8192u;

char ErrorLog[MAX_LOG_SIZE], InfoLog[MAX_LOG_SIZE];
Expand Down

0 comments on commit 79c28d0

Please sign in to comment.