diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index 68e3e665d2..454d09bbb7 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -1697,16 +1697,67 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( return Result; } +namespace { + +enum class GlobalVariableCopy { Read, Write }; + +ur_result_t deviceGlobalCopyHelper( + ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name, + bool blocking, size_t count, size_t offset, void *ptr, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent, GlobalVariableCopy CopyType) { + // Since HIP requires a the global variable to be referenced by name, we use + // metadata to find the correct name to access it by. + auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(name); + if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end()) + return UR_RESULT_ERROR_INVALID_VALUE; + std::string DeviceGlobalName = DeviceGlobalNameIt->second; + + try { + hipDeviceptr_t DeviceGlobal = 0; + size_t DeviceGlobalSize = 0; + UR_CHECK_ERROR(hipModuleGetGlobal(&DeviceGlobal, &DeviceGlobalSize, + hProgram->get(), + DeviceGlobalName.c_str())); + + if (offset + count > DeviceGlobalSize) + return UR_RESULT_ERROR_INVALID_VALUE; + + void *pSrc, *pDst; + if (CopyType == GlobalVariableCopy::Write) { + pSrc = ptr; + pDst = reinterpret_cast(DeviceGlobal) + offset; + } else { + pSrc = reinterpret_cast(DeviceGlobal) + offset; + pDst = ptr; + } + return urEnqueueUSMMemcpy(hQueue, blocking, pDst, pSrc, count, + numEventsInWaitList, phEventWaitList, phEvent); + } catch (ur_result_t Err) { + return Err; + } +} +} // namespace + UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite( - ur_queue_handle_t, ur_program_handle_t, const char *, bool, size_t, size_t, - const void *, uint32_t, const ur_event_handle_t *, ur_event_handle_t *) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name, + bool blockingWrite, size_t count, size_t offset, const void *pSrc, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + return deviceGlobalCopyHelper(hQueue, hProgram, name, blockingWrite, count, + offset, const_cast(pSrc), + numEventsInWaitList, phEventWaitList, phEvent, + GlobalVariableCopy::Write); } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( - ur_queue_handle_t, ur_program_handle_t, const char *, bool, size_t, size_t, - void *, uint32_t, const ur_event_handle_t *, ur_event_handle_t *) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name, + bool blockingRead, size_t count, size_t offset, void *pDst, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + return deviceGlobalCopyHelper( + hQueue, hProgram, name, blockingRead, count, offset, pDst, + numEventsInWaitList, phEventWaitList, phEvent, GlobalVariableCopy::Read); } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe( diff --git a/source/adapters/hip/program.cpp b/source/adapters/hip/program.cpp index 9aa64151e0..81f1be1194 100644 --- a/source/adapters/hip/program.cpp +++ b/source/adapters/hip/program.cpp @@ -78,6 +78,15 @@ void getCoMgrBuildLog(const amd_comgr_data_set_t BuildDataSet, char *BuildLog, } // namespace #endif +std::pair +splitMetadataName(const std::string &metadataName) { + size_t splitPos = metadataName.rfind('@'); + if (splitPos == std::string::npos) + return std::make_pair(metadataName, std::string{}); + return std::make_pair(metadataName.substr(0, splitPos), + metadataName.substr(splitPos, metadataName.length())); +} + ur_result_t ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata, size_t Length) { @@ -85,10 +94,19 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata, const ur_program_metadata_t MetadataElement = Metadata[i]; std::string MetadataElementName{MetadataElement.pName}; + auto [Prefix, Tag] = splitMetadataName(MetadataElementName); + if (MetadataElementName == __SYCL_UR_PROGRAM_METADATA_TAG_NEED_FINALIZATION) { assert(MetadataElement.type == UR_PROGRAM_METADATA_TYPE_UINT32); IsRelocatable = MetadataElement.value.data32; + } else if (Tag == __SYCL_UR_PROGRAM_METADATA_GLOBAL_ID_MAPPING) { + const char *MetadataValPtr = + reinterpret_cast(MetadataElement.value.pData) + + sizeof(std::uint64_t); + const char *MetadataValPtrEnd = + MetadataValPtr + MetadataElement.size - sizeof(std::uint64_t); + GlobalIDMD[Prefix] = std::string{MetadataValPtr, MetadataValPtrEnd}; } } return UR_RESULT_SUCCESS; diff --git a/source/adapters/hip/program.hpp b/source/adapters/hip/program.hpp index 4b4e5ec878..dbdf9c55c6 100644 --- a/source/adapters/hip/program.hpp +++ b/source/adapters/hip/program.hpp @@ -29,6 +29,8 @@ struct ur_program_handle_t_ { // Metadata bool IsRelocatable = false; + std::unordered_map GlobalIDMD; + constexpr static size_t MAX_LOG_SIZE = 8192u; char ErrorLog[MAX_LOG_SIZE], InfoLog[MAX_LOG_SIZE];