Skip to content

Commit

Permalink
Use camelCase always. Return caught UR error directly.
Browse files Browse the repository at this point in the history
Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
  • Loading branch information
JackAKirk committed Jan 9, 2024
1 parent 0cdd69f commit 9bd455f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 22 deletions.
20 changes: 8 additions & 12 deletions source/adapters/cuda/usm_p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,24 @@

UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp(
ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) {

ur_result_t result = UR_RESULT_SUCCESS;
try {
ScopedContext active(commandDevice->getContext());
UR_CHECK_ERROR(cuCtxEnablePeerAccess(peerDevice->getContext(), 0));
} catch (ur_result_t err) {
result = err;
return err;
}
return result;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp(
ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) {

ur_result_t result = UR_RESULT_SUCCESS;
try {
ScopedContext active(commandDevice->getContext());
UR_CHECK_ERROR(cuCtxDisablePeerAccess(peerDevice->getContext()));
} catch (ur_result_t err) {
result = err;
return err;
}
return result;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp(
Expand All @@ -45,16 +41,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp(
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

int value;
CUdevice_P2PAttribute cu_attr;
CUdevice_P2PAttribute cuAttr;
try {
ScopedContext active(commandDevice->getContext());
switch (propName) {
case UR_EXP_PEER_INFO_UR_PEER_ACCESS_SUPPORTED: {
cu_attr = CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED;
cuAttr = CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED;
break;
}
case UR_EXP_PEER_INFO_UR_PEER_ATOMICS_SUPPORTED: {
cu_attr = CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED;
cuAttr = CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED;
break;
}
default: {
Expand All @@ -63,7 +59,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp(
}

UR_CHECK_ERROR(cuDeviceGetP2PAttribute(
&value, cu_attr, commandDevice->get(), peerDevice->get()));
&value, cuAttr, commandDevice->get(), peerDevice->get()));
} catch (ur_result_t err) {
return err;
}
Expand Down
18 changes: 8 additions & 10 deletions source/adapters/hip/usm_p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,24 @@

UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp(
ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) {
ur_result_t result = UR_RESULT_SUCCESS;
try {
ScopedContext active(commandDevice);
UR_CHECK_ERROR(hipDeviceEnablePeerAccess(peerDevice->get(), 0));
} catch (ur_result_t err) {
result = err;
return err;
}
return result;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp(
ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) {
ur_result_t result = UR_RESULT_SUCCESS;
try {
ScopedContext active(commandDevice);
UR_CHECK_ERROR(hipDeviceDisablePeerAccess(peerDevice->get()));
} catch (ur_result_t err) {
result = err;
return err;
}
return result;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp(
Expand All @@ -42,24 +40,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp(
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

int value;
hipDeviceP2PAttr hip_attr;
hipDeviceP2PAttr hipAttr;
try {
ScopedContext active(commandDevice);
switch (propName) {
case UR_EXP_PEER_INFO_UR_PEER_ACCESS_SUPPORTED: {
hip_attr = hipDevP2PAttrAccessSupported;
hipAttr = hipDevP2PAttrAccessSupported;
break;
}
case UR_EXP_PEER_INFO_UR_PEER_ATOMICS_SUPPORTED: {
hip_attr = hipDevP2PAttrNativeAtomicSupported;
hipAttr = hipDevP2PAttrNativeAtomicSupported;
break;
}
default: {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
}
UR_CHECK_ERROR(hipDeviceGetP2PAttribute(
&value, hip_attr, commandDevice->get(), peerDevice->get()));
&value, hipAttr, commandDevice->get(), peerDevice->get()));
} catch (ur_result_t err) {
return err;
}
Expand Down

0 comments on commit 9bd455f

Please sign in to comment.