[core][compiled graphs] allreduce errors with num returns > 1 (#49300)
Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
kevin85421 authored Dec 19, 2024
1 parent 1b07eaf commit 2dd6061
Showing 3 changed files with 54 additions and 2 deletions.
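For context, the new test below exercises the pattern this commit fixes: binding the individual outputs of a @ray.method(num_returns=2) actor method into collective.allreduce.bind inside a Compiled Graph. The following minimal sketch is distilled from that test; the import paths are assumptions that may differ across Ray versions, and running it requires two GPUs with NCCL.

import ray
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental import collective
from ray.experimental.util.types import ReduceOp

@ray.remote(num_gpus=1)
class TorchTensorWorker:
    @ray.method(num_returns=2)
    def return_two_tensors(self, t1: torch.Tensor, t2: torch.Tensor):
        return t1, t2

workers = [TorchTensorWorker.remote() for _ in range(2)]

with InputNode() as inp:
    # Each bind() yields two class-method output nodes. Before this commit,
    # feeding them into allreduce errored; the actor.py change below attaches
    # PARENT_CLASS_NODE_KEY to each output node so the owning actor is recorded.
    t1, t2 = workers[0].return_two_tensors.bind(inp[0], inp[1])
    t3, t4 = workers[1].return_two_tensors.bind(inp[2], inp[3])
    reduced = collective.allreduce.bind([t1, t4], ReduceOp.SUM)
    dag = MultiOutputNode(reduced + [t2, t3])

compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(
    torch.tensor([1], device="cuda"),
    torch.tensor([2], device="cuda"),
    torch.tensor([3], device="cuda"),
    torch.tensor([4], device="cuda"),
)
result = ray.get(ref)  # [t1 + t4, t1 + t4, t2, t3]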
2 changes: 1 addition & 1 deletion python/ray/actor.py
@@ -302,7 +302,7 @@ def _bind(
(node, i),
dict(),
dict(),
{IS_CLASS_METHOD_OUTPUT_KEY: True},
{IS_CLASS_METHOD_OUTPUT_KEY: True, PARENT_CLASS_NODE_KEY: actor},
)
output_nodes.append(output_node)
return tuple(output_nodes)
43 changes: 43 additions & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -101,6 +101,12 @@ def recv_tensor(self, tensor):
def ping(self):
return

@ray.method(num_returns=2)
def return_two_tensors(
self, t1: torch.Tensor, t2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
return t1, t2


@ray.remote(num_cpus=1)
class TrainWorker:
@@ -1245,6 +1251,43 @@ def test_torch_tensor_nccl_all_reduce_scheduling(ray_start_regular):
assert result[2] == (value, shape, dtype)


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_nccl_all_reduce_with_class_method_output_node(ray_start_regular):
"""
Test all-reduce with class method output node.
"""
if not USE_GPU:
pytest.skip("NCCL tests require GPUs")

assert (
sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
), "This test requires at least 2 GPUs"

actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

num_workers = 2
workers = [actor_cls.remote() for _ in range(num_workers)]

with InputNode() as inp:
t1, t2 = workers[0].return_two_tensors.bind(inp[0], inp[1])
t3, t4 = workers[1].return_two_tensors.bind(inp[2], inp[3])
tensors = collective.allreduce.bind([t1, t4], ReduceOp.SUM)
dag = MultiOutputNode(tensors + [t2, t3])

compiled_dag = dag.experimental_compile()

t1 = torch.tensor([1], device="cuda")
t2 = torch.tensor([2], device="cuda")
t3 = torch.tensor([3], device="cuda")
t4 = torch.tensor([4], device="cuda")

for i in range(3):
i += 1
ref = compiled_dag.execute(t1, t2, t3, t4)
result = ray.get(ref)
assert result == [t1 + t4, t1 + t4, t2, t3]


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 2}], indirect=True)
def test_tensor_writable_warning_suppressed(ray_start_regular):
When we move a CPU tensor to GPU, aDAG does zero-copy with is_writable=False.
11 changes: 10 additions & 1 deletion python/ray/experimental/channel/nccl_group.py
@@ -262,6 +262,11 @@ def allreduce(
if self._closed:
raise RayChannelError("NCCL group has been destroyed.")

assert send_buf.dtype == recv_buf.dtype, (
"Ray Compiled Graph derived the dtype of recv_buf from send_buf, "
"so send_buf and recv_buf must have the same dtype. "
"If you see this error, please file an issue at Ray repository."
)
self._comm.allReduce(
self.nccl_util.get_tensor_ptr(send_buf),
self.nccl_util.get_tensor_ptr(recv_buf),
@@ -278,7 +283,11 @@ def allreduce(
# TODO(wxdeng): Use check_async_error.
self._cuda_stream.synchronize()
if self._closed:
raise RayChannelError("NCCL group has been destroyed.")
raise RayChannelError(
"NCCL group has been destroyed during allreduce operation. "
"There may be a dtype mismatch between input tensors from "
"different ranks."
)

@property
def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]:
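The dtype assertion added above documents an invariant of this API: the dtype of recv_buf is derived from send_buf, so the two buffers must always agree. A standalone illustration of that invariant follows; the helper name check_allreduce_bufs is hypothetical, used here only for illustration.

import torch

def check_allreduce_bufs(send_buf: torch.Tensor, recv_buf: torch.Tensor) -> None:
    # Mirrors the new assertion in allreduce: recv_buf's dtype is derived
    # from send_buf, so a mismatch indicates an internal bug.
    assert send_buf.dtype == recv_buf.dtype, "send_buf and recv_buf dtypes differ"

send = torch.tensor([1.0])        # float32
recv = torch.empty_like(send)     # same dtype by construction
check_allreduce_bufs(send, recv)  # passes
# check_allreduce_bufs(send, recv.to(torch.float16))  # would raise AssertionError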