Skip to content

Commit

Permalink
Fixes for AM and enable AM tests that were previously skipped (#315)
Browse files Browse the repository at this point in the history
Fix AM receive message cancelation in C++ implementation that requires a custom `cancel()` implementation, as well as log when allocator raised an exception and return `UCXNoMemoryError` to requests where allocation failed. Reimplement Python AM RMM allocator with a pure C++ function to prevent Cython from introducing Python exception handlers that should not occur in the allocators as the C++ backend should not require the GIL.

Enable AM tests that were previously skipped and update `test_send_recv_two_workers` to match AM implementation.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)
  - James Lamb (https://github.com/jameslamb)

URL: #315
  • Loading branch information
pentschev authored Nov 12, 2024
1 parent 7069174 commit eddc5d2
Show file tree
Hide file tree
Showing 11 changed files with 122 additions and 95 deletions.
2 changes: 1 addition & 1 deletion ci/test_common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ run_py_tests_async() {
ENABLE_PYTHON_FUTURE=$3
SKIP=$4

CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --durations=50"
CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow"

if [ $SKIP -ne 0 ]; then
echo -e "\e[1;33mSkipping unstable test: ${CMD_LINE}\e[0m"
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class Request : public Component {
/**
* @brief Cancel the request.
*
* Cancel the request. Often called by the an error handler or parent's object
* Cancel the request. Often called by the error handler or parent's object
* destructor but may be called by the user to cancel the request as well.
*/
virtual void cancel();
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/ucxx/request_am.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ class RequestAm : public Request {
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

/**
* @brief Cancel the request.
*
* Cancel the request. Often called by the error handler or parent's object
* destructor but may be called by the user to cancel the request as well.
*/
void cancel() override;

void populateDelayedSubmission() override;

/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts)
if (_endpointErrorHandling)
param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE};

auto worker = ::ucxx::getWorker(_parent);
ucs_status_ptr_t status;
auto worker = ::ucxx::getWorker(_parent);
ucs_status_ptr_t status = nullptr;

if (worker->isProgressThreadRunning()) {
bool closeSuccess = false;
Expand Down
55 changes: 46 additions & 9 deletions cpp/src/request_am.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,26 @@ RequestAm::RequestAm(std::shared_ptr<Component> endpointOrWorker,
requestData);
}

void RequestAm::cancel()
{
std::lock_guard<std::recursive_mutex> lock(_mutex);
if (_status == UCS_INPROGRESS) {
/**
* This is needed to ensure AM requests are cancelable, since they do not
* use the `_request`, thus `ucp_request_cancel()` cannot cancel them.
*/
setStatus(UCS_ERR_CANCELED);
} else {
ucxx_trace_req_f(_ownerString.c_str(),
this,
_request,
_operationName.c_str(),
"already completed with status: %d (%s)",
_status,
ucs_status_string(_status));
}
}

static void _amSendCallback(void* request, ucs_status_t status, void* user_data)
{
Request* req = reinterpret_cast<Request*>(user_data);
Expand Down Expand Up @@ -248,19 +268,29 @@ ucs_status_t RequestAm::recvCallback(void* arg,
amHeader.memoryType = UCS_MEMORY_TYPE_HOST;
}

std::shared_ptr<Buffer> buf = amData->_allocators.at(amHeader.memoryType)(length);
try {
buf = amData->_allocators.at(amHeader.memoryType)(length);
} catch (const std::exception& e) {
ucxx_debug("Exception calling allocator: %s", e.what());
}

auto recvAmMessage =
std::make_shared<internal::RecvAmMessage>(amData, ep, req, buf, receiverCallback);

ucp_request_param_t request_param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA |
UCP_OP_ATTR_FLAG_NO_IMM_CMPL,
.cb = {.recv_am = _recvCompletedCallback},
.user_data = recvAmMessage.get()};
ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA |
UCP_OP_ATTR_FLAG_NO_IMM_CMPL,
.cb = {.recv_am = _recvCompletedCallback},
.user_data = recvAmMessage.get()};

if (buf == nullptr) {
ucxx_debug("Failed to allocate %lu bytes of memory", length);
recvAmMessage->_request->setStatus(UCS_ERR_NO_MEMORY);
return UCS_ERR_NO_MEMORY;
}

ucs_status_ptr_t status =
ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &request_param);
ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &requestParam);

if (req->_enablePythonFuture)
ucxx_trace_req_f(ownerString.c_str(),
Expand Down Expand Up @@ -302,7 +332,15 @@ ucs_status_t RequestAm::recvCallback(void* arg,
return UCS_INPROGRESS;
}
} else {
std::shared_ptr<Buffer> buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length);
buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length);

internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback);
if (buf == nullptr) {
ucxx_debug("Failed to allocate %lu bytes of memory", length);
recvAmMessage._request->setStatus(UCS_ERR_NO_MEMORY);
return UCS_ERR_NO_MEMORY;
}

if (length > 0) memcpy(buf->data(), data, length);

if (req->_enablePythonFuture)
Expand All @@ -326,7 +364,6 @@ ucs_status_t RequestAm::recvCallback(void* arg,
buf->data(),
length);

internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback);
recvAmMessage.callback(nullptr, UCS_OK);
return UCS_OK;
}
Expand Down
2 changes: 1 addition & 1 deletion python/ucxx/ucxx/_lib/libucxx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_host_buffer(uintptr_t recv_buffer_ptr):
return np.asarray(HostBufferAdapter._from_host_buffer(host_buffer))


cdef shared_ptr[Buffer] _rmm_am_allocator(size_t length):
cdef shared_ptr[Buffer] _rmm_am_allocator(size_t length) noexcept nogil:
cdef shared_ptr[RMMBuffer] rmm_buffer = make_shared[RMMBuffer](length)
return dynamic_pointer_cast[Buffer, RMMBuffer](rmm_buffer)

Expand Down
2 changes: 1 addition & 1 deletion python/ucxx/ucxx/_lib/ucxx_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ cdef extern from "<ucxx/buffer.h>" namespace "ucxx" nogil:
void* data() except +raise_py_error

cdef cppclass RMMBuffer:
RMMBuffer(const size_t size_t)
RMMBuffer(const size_t size_t) except +raise_py_error
BufferType getType()
size_t getSize()
unique_ptr[device_buffer] release() except +raise_py_error
Expand Down
3 changes: 0 additions & 3 deletions python/ucxx/ucxx/_lib_async/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ async def client_node(port):
@pytest.mark.asyncio
@pytest.mark.parametrize("transfer_api", ["am", "tag", "tag_multi"])
async def test_cancel(transfer_api):
if transfer_api == "am":
pytest.skip("AM not implemented yet")

q = Queue()

async def server_node(ep):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ async def run():
},
)
def test_from_worker_address_error(error_type):
if error_type in ["timeout_am_send", "timeout_am_recv"]:
pytest.skip("AM not implemented yet")

q1 = mp.Queue()
q2 = mp.Queue()

Expand Down
Loading

0 comments on commit eddc5d2

Please sign in to comment.