From 2dd606198c83ccc86c4dbbef5f59499041e85ba8 Mon Sep 17 00:00:00 2001 From: Kai-Hsun Chen Date: Thu, 19 Dec 2024 13:40:22 -0800 Subject: [PATCH] [core][compiled graphs] `allreduce` errors with num returns > 1 (#49300) Signed-off-by: Kai-Hsun Chen --- python/ray/actor.py | 2 +- .../experimental/test_torch_tensor_dag.py | 43 +++++++++++++++++++ python/ray/experimental/channel/nccl_group.py | 11 ++++- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/python/ray/actor.py b/python/ray/actor.py index 08c0a26cf381..450b1c620a1e 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -302,7 +302,7 @@ def _bind( (node, i), dict(), dict(), - {IS_CLASS_METHOD_OUTPUT_KEY: True}, + {IS_CLASS_METHOD_OUTPUT_KEY: True, PARENT_CLASS_NODE_KEY: actor}, ) output_nodes.append(output_node) return tuple(output_nodes) diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index b04e67eddac5..449548a5b269 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -101,6 +101,12 @@ def recv_tensor(self, tensor): def ping(self): return + @ray.method(num_returns=2) + def return_two_tensors( + self, t1: torch.Tensor, t2: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return t1, t2 + @ray.remote(num_cpus=1) class TrainWorker: @@ -1245,6 +1251,43 @@ def test_torch_tensor_nccl_all_reduce_scheduling(ray_start_regular): assert result[2] == (value, shape, dtype) +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_nccl_all_reduce_with_class_method_output_node(ray_start_regular): + """ + Test all-reduce with class method output node. + """ + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 GPUs" + + actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + with InputNode() as inp: + t1, t2 = workers[0].return_two_tensors.bind(inp[0], inp[1]) + t3, t4 = workers[1].return_two_tensors.bind(inp[2], inp[3]) + tensors = collective.allreduce.bind([t1, t4], ReduceOp.SUM) + dag = MultiOutputNode(tensors + [t2, t3]) + + compiled_dag = dag.experimental_compile() + + t1 = torch.tensor([1], device="cuda") + t2 = torch.tensor([2], device="cuda") + t3 = torch.tensor([3], device="cuda") + t4 = torch.tensor([4], device="cuda") + + for i in range(3): + i += 1 + ref = compiled_dag.execute(t1, t2, t3, t4) + result = ray.get(ref) + assert result == [t1 + t4, t1 + t4, t2, t3] + + @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 2}], indirect=True) def test_tensor_writable_warning_suppressed(ray_start_regular): """When we move cpu tensor to gpu, aDAG does zero-copy with is_wriatble=False. diff --git a/python/ray/experimental/channel/nccl_group.py b/python/ray/experimental/channel/nccl_group.py index 2246c1996651..fd295425794c 100644 --- a/python/ray/experimental/channel/nccl_group.py +++ b/python/ray/experimental/channel/nccl_group.py @@ -262,6 +262,11 @@ def allreduce( if self._closed: raise RayChannelError("NCCL group has been destroyed.") + assert send_buf.dtype == recv_buf.dtype, ( + "Ray Compiled Graph derived the dtype of recv_buf from send_buf, " + "so send_buf and recv_buf must have the same dtype. " + "If you see this error, please file an issue at Ray repository." + ) self._comm.allReduce( self.nccl_util.get_tensor_ptr(send_buf), self.nccl_util.get_tensor_ptr(recv_buf), @@ -278,7 +283,11 @@ def allreduce( # TODO(wxdeng): Use check_async_error. self._cuda_stream.synchronize() if self._closed: - raise RayChannelError("NCCL group has been destroyed.") + raise RayChannelError( + "NCCL group has been destroyed during allreduce operation. " + "There may be a dtype mismatch between input tensors from " + "different ranks." + ) @property def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]: