diff --git a/source/adapters/hip/usm_p2p.cpp b/source/adapters/hip/usm_p2p.cpp index 65635dc910..d0d25c2092 100644 --- a/source/adapters/hip/usm_p2p.cpp +++ b/source/adapters/hip/usm_p2p.cpp @@ -9,25 +9,57 @@ //===----------------------------------------------------------------------===// #include "common.hpp" +#include "context.hpp" -UR_APIEXPORT ur_result_t UR_APICALL -urUsmP2PEnablePeerAccessExp(ur_device_handle_t, ur_device_handle_t) { - detail::ur::die( - "urUsmP2PEnablePeerAccessExp is not implemented for HIP adapter."); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp( + ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) { + try { + ScopedContext active(commandDevice); + UR_CHECK_ERROR(hipDeviceEnablePeerAccess(peerDevice->get(), 0)); + } catch (ur_result_t err) { + return err; + } + return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL -urUsmP2PDisablePeerAccessExp(ur_device_handle_t, ur_device_handle_t) { - detail::ur::die( - "urUsmP2PDisablePeerAccessExp is not implemented for HIP adapter."); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp( + ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) { + try { + ScopedContext active(commandDevice); + UR_CHECK_ERROR(hipDeviceDisablePeerAccess(peerDevice->get())); + } catch (ur_result_t err) { + return err; + } + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp( - ur_device_handle_t, ur_device_handle_t, ur_exp_peer_info_t, size_t propSize, - void *pPropValue, size_t *pPropSizeRet) { + ur_device_handle_t commandDevice, ur_device_handle_t peerDevice, + ur_exp_peer_info_t propName, size_t propSize, void *pPropValue, + size_t *pPropSizeRet) { UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); - // Zero return value indicates that all of the queries currently return false. - return ReturnValue(uint32_t{0}); + + int value; + hipDeviceP2PAttr hipAttr; + try { + ScopedContext active(commandDevice); + switch (propName) { + case UR_EXP_PEER_INFO_UR_PEER_ACCESS_SUPPORTED: { + hipAttr = hipDevP2PAttrAccessSupported; + break; + } + case UR_EXP_PEER_INFO_UR_PEER_ATOMICS_SUPPORTED: { + hipAttr = hipDevP2PAttrNativeAtomicSupported; + break; + } + default: { + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + } + UR_CHECK_ERROR(hipDeviceGetP2PAttribute( + &value, hipAttr, commandDevice->get(), peerDevice->get())); + } catch (ur_result_t err) { + return err; + } + return ReturnValue(value); }