diff --git a/ci/test_common.sh b/ci/test_common.sh
index 48592e36..550ec9cb 100755
--- a/ci/test_common.sh
+++ b/ci/test_common.sh
@@ -128,7 +128,7 @@ run_py_tests_async() {
   ENABLE_PYTHON_FUTURE=$3
   SKIP=$4
 
-  CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --durations=50"
+  CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow"
 
   if [ $SKIP -ne 0 ]; then
     echo -e "\e[1;33mSkipping unstable test: ${CMD_LINE}\e[0m"
diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h
index 2584dda5..0a7b1505 100644
--- a/cpp/include/ucxx/request.h
+++ b/cpp/include/ucxx/request.h
@@ -120,7 +120,7 @@ class Request : public Component {
   /**
    * @brief Cancel the request.
    *
-   * Cancel the request. Often called by the an error handler or parent's object
+   * Cancel the request. Often called by the error handler or parent object's
    * destructor but may be called by the user to cancel the request as well.
    */
   virtual void cancel();
diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h
index fb5cfd96..b7937c0c 100644
--- a/cpp/include/ucxx/request_am.h
+++ b/cpp/include/ucxx/request_am.h
@@ -96,6 +96,14 @@ class RequestAm : public Request {
                     RequestCallbackUserFunction callbackFunction,
                     RequestCallbackUserData callbackData);
 
+  /**
+   * @brief Cancel the request.
+   *
+   * Cancel the request. Often called by the error handler or parent object's
+   * destructor but may be called by the user to cancel the request as well.
+   */
+  void cancel() override;
+
   void populateDelayedSubmission() override;
 
   /**
diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp
index 4731b78d..5f5870c3 100644
--- a/cpp/src/endpoint.cpp
+++ b/cpp/src/endpoint.cpp
@@ -262,8 +262,8 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts)
   if (_endpointErrorHandling)
     param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE};
 
-  auto worker = ::ucxx::getWorker(_parent);
-  ucs_status_ptr_t status;
+  auto worker             = ::ucxx::getWorker(_parent);
+  ucs_status_ptr_t status = nullptr;
 
   if (worker->isProgressThreadRunning()) {
     bool closeSuccess = false;
diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp
index 713f6bdc..de29de0d 100644
--- a/cpp/src/request_am.cpp
+++ b/cpp/src/request_am.cpp
@@ -154,6 +154,26 @@ RequestAm::RequestAm(std::shared_ptr<Component> endpointOrWorker,
                      requestData);
 }
 
+void RequestAm::cancel()
+{
+  std::lock_guard<std::recursive_mutex> lock(_mutex);
+  if (_status == UCS_INPROGRESS) {
+    /**
+     * This is needed to ensure AM requests are cancelable, since they do not
+     * use the `_request`, thus `ucp_request_cancel()` cannot cancel them.
+     */
+    setStatus(UCS_ERR_CANCELED);
+  } else {
+    ucxx_trace_req_f(_ownerString.c_str(),
+                     this,
+                     _request,
+                     _operationName.c_str(),
+                     "already completed with status: %d (%s)",
+                     _status,
+                     ucs_status_string(_status));
+  }
+}
+
 static void _amSendCallback(void* request, ucs_status_t status, void* user_data)
 {
   Request* req = reinterpret_cast<Request*>(user_data);
@@ -248,19 +268,29 @@ ucs_status_t RequestAm::recvCallback(void* arg,
     amHeader.memoryType = UCS_MEMORY_TYPE_HOST;
   }
 
-  std::shared_ptr<Buffer> buf = amData->_allocators.at(amHeader.memoryType)(length);
+  try {
+    buf = amData->_allocators.at(amHeader.memoryType)(length);
+  } catch (const std::exception& e) {
+    ucxx_debug("Exception calling allocator: %s", e.what());
+  }
 
   auto recvAmMessage = std::make_shared<internal::RecvAmMessage>(amData, ep, req, buf, receiverCallback);
 
-  ucp_request_param_t request_param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
-                                                       UCP_OP_ATTR_FIELD_USER_DATA |
-                                                       UCP_OP_ATTR_FLAG_NO_IMM_CMPL,
-                                       .cb        = {.recv_am = _recvCompletedCallback},
-                                       .user_data = recvAmMessage.get()};
+  ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
+                                                      UCP_OP_ATTR_FIELD_USER_DATA |
+                                                      UCP_OP_ATTR_FLAG_NO_IMM_CMPL,
+                                      .cb        = {.recv_am = _recvCompletedCallback},
+                                      .user_data = recvAmMessage.get()};
+
+  if (buf == nullptr) {
+    ucxx_debug("Failed to allocate %lu bytes of memory", length);
+    recvAmMessage->_request->setStatus(UCS_ERR_NO_MEMORY);
+    return UCS_ERR_NO_MEMORY;
+  }
 
   ucs_status_ptr_t status =
-    ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &request_param);
+    ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &requestParam);
 
   if (req->_enablePythonFuture)
     ucxx_trace_req_f(ownerString.c_str(),
@@ -302,7 +332,15 @@ ucs_status_t RequestAm::recvCallback(void* arg,
       return UCS_INPROGRESS;
     }
   } else {
-    std::shared_ptr<Buffer> buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length);
+    buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length);
+
+    internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback);
+    if (buf == nullptr) {
+      ucxx_debug("Failed to allocate %lu bytes of memory", length);
+      recvAmMessage._request->setStatus(UCS_ERR_NO_MEMORY);
+      return UCS_ERR_NO_MEMORY;
+    }
+
     if (length > 0) memcpy(buf->data(), data, length);
 
     if (req->_enablePythonFuture)
@@ -326,7 +364,6 @@ ucs_status_t RequestAm::recvCallback(void* arg,
                        buf->data(),
                        length);
 
-    internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback);
     recvAmMessage.callback(nullptr, UCS_OK);
     return UCS_OK;
   }
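Note: with `RequestAm::cancel()` in place, a pending AM receive no longer hangs when its endpoint goes away; the request completes with `UCS_ERR_CANCELED`, which the async Python layer surfaces as `UCXCanceledError`. A minimal sketch of what this enables, using only names appearing elsewhere in this diff (`ucxx.create_endpoint`, `am_recv`); the flow is illustrative, not a verbatim test:

import pytest
import ucxx

async def client_node(port):
    # Connect and post a receive that the peer never satisfies; when the peer
    # closes the endpoint, the AM request is now canceled instead of staying
    # in progress forever.
    ep = await ucxx.create_endpoint(ucxx.get_address(), port)
    with pytest.raises(ucxx.exceptions.UCXCanceledError):
        await ep.am_recv()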
diff --git a/python/ucxx/ucxx/_lib/libucxx.pyx b/python/ucxx/ucxx/_lib/libucxx.pyx
index 2455a781..09bf68d4 100644
--- a/python/ucxx/ucxx/_lib/libucxx.pyx
+++ b/python/ucxx/ucxx/_lib/libucxx.pyx
@@ -95,7 +95,7 @@ def _get_host_buffer(uintptr_t recv_buffer_ptr):
     return np.asarray(HostBufferAdapter._from_host_buffer(host_buffer))
 
 
-cdef shared_ptr[Buffer] _rmm_am_allocator(size_t length):
+cdef shared_ptr[Buffer] _rmm_am_allocator(size_t length) noexcept nogil:
     cdef shared_ptr[RMMBuffer] rmm_buffer = make_shared[RMMBuffer](length)
     return dynamic_pointer_cast[Buffer, RMMBuffer](rmm_buffer)
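Note: the `noexcept nogil` annotation matters because this allocator is presumably invoked from the C++ progress thread without the GIL; a Python exception cannot propagate through that boundary, so allocation failures are reported through the `UCS_ERR_NO_MEMORY` path added above instead. For reference, the per-test helper deleted later in this diff registered custom allocators like this (reconstructed from the removed test code; `rmm` is needed only for the CUDA allocator):

import numpy as np
import rmm
import ucxx

def register_am_allocators():
    # Host buffers come from NumPy, device buffers from RMM. If an allocator
    # fails under memory pressure, the receive now surfaces UCXNoMemoryError
    # rather than crashing the progress thread.
    ucxx.register_am_allocator(lambda n: np.empty(n, dtype=np.uint8), "host")
    ucxx.register_am_allocator(lambda n: rmm.DeviceBuffer(size=n), "cuda")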
diff --git a/python/ucxx/ucxx/_lib/ucxx_api.pxd b/python/ucxx/ucxx/_lib/ucxx_api.pxd
index 9c30a4c3..ec8cef1a 100644
--- a/python/ucxx/ucxx/_lib/ucxx_api.pxd
+++ b/python/ucxx/ucxx/_lib/ucxx_api.pxd
@@ -155,7 +155,7 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
         void* data() except +raise_py_error
 
     cdef cppclass RMMBuffer:
-        RMMBuffer(const size_t size)
+        RMMBuffer(const size_t size) except +raise_py_error
         BufferType getType()
         size_t getSize()
         unique_ptr[device_buffer] release() except +raise_py_error
diff --git a/python/ucxx/ucxx/_lib_async/tests/test_endpoint.py b/python/ucxx/ucxx/_lib_async/tests/test_endpoint.py
index abd82669..5b127fa9 100644
--- a/python/ucxx/ucxx/_lib_async/tests/test_endpoint.py
+++ b/python/ucxx/ucxx/_lib_async/tests/test_endpoint.py
@@ -58,9 +58,6 @@ async def client_node(port):
 @pytest.mark.asyncio
 @pytest.mark.parametrize("transfer_api", ["am", "tag", "tag_multi"])
 async def test_cancel(transfer_api):
-    if transfer_api == "am":
-        pytest.skip("AM not implemented yet")
-
     q = Queue()
 
     async def server_node(ep):
diff --git a/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py b/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py
index 5832fa46..800d5879 100644
--- a/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py
+++ b/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py
@@ -162,9 +162,6 @@ async def run():
     },
 )
 def test_from_worker_address_error(error_type):
-    if error_type in ["timeout_am_send", "timeout_am_recv"]:
-        pytest.skip("AM not implemented yet")
-
     q1 = mp.Queue()
     q2 = mp.Queue()
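Note: the `timeout_am_send`/`timeout_am_recv` cases above can only run now that AM requests are cancelable: a receive that will never complete must be able to fail. A sketch of the general shape such tests rely on (`recv_with_deadline` is a hypothetical helper, not part of ucxx):

import asyncio

async def recv_with_deadline(ep, seconds=10.0):
    # Fail fast if neither data nor an error arrives in time. If the peer's
    # error handler fires first, UCXCanceledError propagates instead of
    # asyncio.TimeoutError.
    return await asyncio.wait_for(ep.am_recv(), timeout=seconds)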
diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_two_workers.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_two_workers.py
index 6bd0a6c2..e49783c6 100644
--- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_two_workers.py
+++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_two_workers.py
@@ -9,7 +9,7 @@
 import numpy as np
 import pytest
 
-import ucxx as ucxx
+import ucxx
 from ucxx._lib_async.utils import get_event_loop
 from ucxx._lib_async.utils_test import (
     am_recv,
@@ -27,49 +27,54 @@
 distributed = pytest.importorskip("distributed")
 cloudpickle = pytest.importorskip("cloudpickle")
 
+# Enable for additional debug output
+VERBOSE = False
+
 ITERATIONS = 30
 
 
+def print_with_pid(msg):
+    if VERBOSE:
+        print(f"[{os.getpid()}] {msg}")
+
+
 async def get_ep(name, port):
     addr = ucxx.get_address()
     ep = await ucxx.create_endpoint(addr, port)
     return ep
 
 
-def register_am_allocators():
-    ucxx.register_am_allocator(lambda n: np.empty(n, dtype=np.uint8), "host")
-    ucxx.register_am_allocator(lambda n: rmm.DeviceBuffer(size=n), "cuda")
-
-
 def client(port, func, comm_api):
-    # wait for server to come up
-    # receive cudf object
-    # deserialize
-    # assert deserialized msg is cdf
-    # send receipt
+    # 1. Wait for server to come up
+    # 2. Loop receiving object multiple times from server
+    # 3. Send close message
+    # 4. Assert last received message has correct content
    from distributed.utils import nbytes
 
-    ucxx.init()
-
-    if comm_api == "am":
-        register_am_allocators()
-
     # must create context before importing
     # cudf/cupy/etc
+    ucxx.init()
+
     async def read():
         await asyncio.sleep(1)
         ep = await get_ep("client", port)
-        msg = None
-        import cupy
-
-        cupy.cuda.set_allocator(None)
 
         for i in range(ITERATIONS):
-            print(f"Client iteration {i}")
+            print_with_pid(f"Client iteration {i}")
             if comm_api == "tag":
                 frames, msg = await recv(ep)
             else:
-                frames, msg = await am_recv(ep)
+                while True:
+                    try:
+                        frames, msg = await am_recv(ep)
+                    except ucxx.exceptions.UCXNoMemoryError as e:
+                        # The client didn't receive/consume messages quickly
+                        # enough, so allocating memory for a new AM failed and
+                        # raised this exception; keep trying.
+                        print_with_pid(f"Client exception: {type(e)} {e}")
+                    else:
+                        break
 
         close_msg = b"shutdown listener"
@@ -81,13 +86,13 @@ async def read():
             else:
                 await ep.am_send(close_msg)
 
-        print("Shutting Down Client...")
+        print_with_pid("Shutting Down Client...")
         return msg["data"]
 
     rx_cuda_obj = get_event_loop().run_until_complete(read())
     rx_cuda_obj + rx_cuda_obj
     num_bytes = nbytes(rx_cuda_obj)
-    print(f"TOTAL DATA RECEIVED: {num_bytes}")
+    print_with_pid(f"TOTAL DATA RECEIVED: {num_bytes}")
 
     cuda_obj_generator = cloudpickle.loads(func)
     pure_cuda_obj = cuda_obj_generator()
@@ -101,39 +106,39 @@ def server(port, func, comm_api):
-    # create listener receiver
-    # write cudf object
-    # confirm message is sent correctly
+    # 1. Create listener receiver
+    # 2. Loop sending object multiple times to connected client
+    # 3. Receive close message and close listener
     from distributed.comm.utils import to_frames
     from distributed.protocol import to_serialize
 
     ucxx.init()
 
-    if comm_api == "am":
-        register_am_allocators()
-
     async def f(listener_port):
-        # coroutine shows up when the client asks
-        # to connect
+        # Coroutine shows up when the client asks to connect
         async def write(ep):
-            import cupy
-
-            cupy.cuda.set_allocator(None)
-
-            print("CREATING CUDA OBJECT IN SERVER...")
+            print_with_pid("CREATING CUDA OBJECT IN SERVER...")
             cuda_obj_generator = cloudpickle.loads(func)
             cuda_obj = cuda_obj_generator()
             msg = {"data": to_serialize(cuda_obj)}
             frames = await to_frames(msg, serializers=("cuda", "dask", "pickle"))
             for i in range(ITERATIONS):
-                print(f"Server iteration {i}")
+                print_with_pid(f"Server iteration {i}")
                 # Send meta data
                 if comm_api == "tag":
                     await send(ep, frames)
                 else:
-                    await am_send(ep, frames)
-
-            print("CONFIRM RECEIPT")
+                    while True:
+                        try:
+                            await am_send(ep, frames)
+                        except ucxx.exceptions.UCXNoMemoryError as e:
+                            # Memory pressure due to the client taking too
+                            # long to receive will raise an exception.
+                            print_with_pid(f"Listener exception: {type(e)} {e}")
+                        else:
+                            break
+
+            print_with_pid("CONFIRM RECEIPT")
             close_msg = b"shutdown listener"
             if comm_api == "tag":
@@ -147,7 +152,7 @@ async def write(ep):
             recv_msg = msg.tobytes()
             assert recv_msg == close_msg
 
-            print("Shutting Down Server...")
+            print_with_pid("Shutting Down Server...")
             await ep.close()
             lf.close()
 
@@ -156,10 +161,8 @@ async def write(ep):
         try:
             while not lf.closed:
                 await asyncio.sleep(0.1)
-        # except ucxx.UCXCloseError:
-        #     pass
-        except Exception as e:
-            print(f"Exception: {e=}")
+        except ucxx.UCXCloseError:
+            pass
 
     loop = get_event_loop()
     loop.run_until_complete(f(port))
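Note: the client and server retry loops above share one pattern: an AM operation failing with `UCXNoMemoryError` signals transient allocator pressure on the receiving side and is simply retried. A hypothetical helper distilling it (`retry_on_no_memory` is illustrative, not part of ucxx):

import ucxx

async def retry_on_no_memory(operation):
    # `operation` is a zero-argument callable returning a fresh awaitable,
    # e.g. `lambda: am_recv(ep)` or `lambda: am_send(ep, frames)`.
    while True:
        try:
            return await operation()
        except ucxx.exceptions.UCXNoMemoryError:
            continue  # receiver could not allocate yet; try again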
@@ -199,33 +202,28 @@ def cupy_obj():
 
 
 @pytest.mark.slow
-@pytest.mark.skipif(
-    get_num_gpus() <= 2, reason="Machine does not have more than two GPUs"
-)
+@pytest.mark.skipif(get_num_gpus() < 2, reason="Machine needs at least two GPUs")
 @pytest.mark.parametrize(
     "cuda_obj_generator", [dataframe, empty_dataframe, series, cupy_obj]
 )
 @pytest.mark.parametrize("comm_api", ["tag", "am"])
 def test_send_recv_cu(cuda_obj_generator, comm_api):
-    if comm_api == "am":
-        pytest.skip("AM not implemented yet")
-
     base_env = os.environ
     env_client = base_env.copy()
-    # grab first two devices
+    # Grab first two devices
     cvd = get_cuda_devices()[:2]
     cvd = ",".join(map(str, cvd))
-    # reverse CVD for other worker
+    # Reverse CVD for client
     env_client["CUDA_VISIBLE_DEVICES"] = cvd[::-1]
 
     port = random.randint(13000, 15500)
-    # serialize function and send to the client and server
-    # server will use the return value of the contents,
-    # serialize the values, then send serialized values to client.
-    # client will compare return values of the deserialized
-    # data sent from the server
+    # Serialize function and send to the client and server. The server will use
+    # the return value of the contents, serialize the values, then send
+    # serialized values to client. The client will compare return values of the
+    # deserialized data sent from the server.
     func = cloudpickle.dumps(cuda_obj_generator)
+
     ctx = multiprocessing.get_context("spawn")
     server_process = ctx.Process(
         name="server", target=server, args=[port, func, comm_api]
@@ -235,12 +233,12 @@ def test_send_recv_cu(cuda_obj_generator, comm_api):
     )
 
     server_process.start()
-    # cudf will ping the driver for validity of device
-    # this will influence device on which a cuda context is created.
-    # work around is to update env with new CVD before spawning
+    # cuDF will ping the driver for validity of device; this will influence
+    # the device on which a CUDA context is created. The workaround is to
+    # update the env with the new CVD before spawning.
     os.environ.update(env_client)
     client_process.start()
 
-    join_processes([client, server], timeout=30)
-    terminate_process(client)
-    terminate_process(server)
+    join_processes([client_process, server_process], timeout=3000)
+    terminate_process(client_process)
+    terminate_process(server_process)
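Note: the ordering in `test_send_recv_cu` is deliberate. With the "spawn" start method a child inherits the parent's environment as it exists at `start()`, so `CUDA_VISIBLE_DEVICES` must be updated before the client process is spawned. A standalone illustration (hypothetical `worker`, device string chosen arbitrarily):

import multiprocessing
import os

def worker():
    # The child sees whatever the parent had exported at spawn time.
    print(os.environ.get("CUDA_VISIBLE_DEVICES"))

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    os.environ["CUDA_VISIBLE_DEVICES"] = "1,0"  # must happen before start()
    p = ctx.Process(target=worker)
    p.start()
    p.join()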
diff --git a/python/ucxx/ucxx/_lib_async/tests/test_shutdown.py b/python/ucxx/ucxx/_lib_async/tests/test_shutdown.py
index 163c4fc3..bc39d05a 100644
--- a/python/ucxx/ucxx/_lib_async/tests/test_shutdown.py
+++ b/python/ucxx/ucxx/_lib_async/tests/test_shutdown.py
@@ -35,8 +35,6 @@ async def _shutdown_recv(ep, message_type):
 @pytest.mark.parametrize("message_type", ["tag", "am"])
 async def test_server_shutdown(message_type):
     """The server calls shutdown"""
-    if message_type == "am":
-        pytest.skip("AM not implemented yet")
 
     async def server_node(ep):
         with pytest.raises(ucxx.exceptions.UCXCanceledError):
@@ -67,8 +65,6 @@ async def client_node(port):
 @pytest.mark.parametrize("message_type", ["tag", "am"])
 async def test_client_shutdown(message_type):
     """The client calls shutdown"""
-    if message_type == "am":
-        pytest.skip("AM not implemented yet")
 
     async def client_node(port):
         ep = await ucxx.create_endpoint(
@@ -96,8 +92,6 @@ async def server_node(ep):
 @pytest.mark.parametrize("message_type", ["tag", "am"])
 async def test_listener_close(message_type):
     """The server closes the listener"""
-    if message_type == "am":
-        pytest.skip("AM not implemented yet")
 
     async def client_node(listener):
         ep = await ucxx.create_endpoint(
@@ -125,8 +119,6 @@ async def server_node(ep):
 @pytest.mark.parametrize("message_type", ["tag", "am"])
 async def test_listener_del(message_type):
     """The client deletes the listener"""
-    if message_type == "am":
-        pytest.skip("AM not implemented yet")
 
     async def server_node(ep):
         await _shutdown_send(ep, message_type)
@@ -156,8 +148,6 @@ async def server_node(ep):
 @pytest.mark.parametrize("message_type", ["tag", "am"])
 async def test_close_after_n_recv(message_type):
     """The Endpoint.close_after_n_recv()"""
-    if message_type == "am":
-        pytest.skip("AM not implemented yet")
 
     async def server_node(ep):
         for _ in range(10):
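Note: all five shutdown tests funnel through the module's `_shutdown_send`/`_shutdown_recv` helpers, parametrized by `message_type`; enabling "am" means the same assertions now cover both transports. A plausible reconstruction of those helpers (the bodies are inferred; only the names and the tag/am split appear in this diff):

import ucxx

async def _shutdown_send(ep, message_type):
    msg = bytearray(b"shutdown")
    if message_type == "tag":
        await ep.send(msg)
    else:
        await ep.am_send(msg)

async def _shutdown_recv(ep, message_type):
    if message_type == "tag":
        msg = bytearray(len(b"shutdown"))
        await ep.recv(msg)
    else:
        await ep.am_recv()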