diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py
index acec6c2672cd..c964b890f3d4 100644
--- a/python/ray/dag/compiled_dag_node.py
+++ b/python/ray/dag/compiled_dag_node.py
@@ -185,7 +185,7 @@ def do_profile_tasks(
     """
     try:
         for task in tasks:
-            task.prepare()
+            task.prepare(overlap_gpu_communication=overlap_gpu_communication)
 
         if not hasattr(self, "__ray_adag_events"):
             self.__ray_adag_events = []
diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
index d1ac1c68063f..1797068e7e2d 100644
--- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
+++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -182,7 +182,11 @@ def test_torch_tensor_as_dag_input(ray_start_regular):
 
 
 @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
-def test_torch_tensor_nccl(ray_start_regular):
+@pytest.mark.parametrize("enable_profiling", [False, True])
+@pytest.mark.parametrize("overlap_gpu_communication", [False, True])
+def test_torch_tensor_nccl(
+    ray_start_regular, monkeypatch, enable_profiling, overlap_gpu_communication
+):
     if not USE_GPU:
         pytest.skip("NCCL tests require GPUs")
 
@@ -190,6 +194,10 @@ def test_torch_tensor_nccl(ray_start_regular):
         sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
     ), "This test requires at least 2 GPUs"
 
+    monkeypatch.setattr(
+        ray.dag.constants, "RAY_ADAG_ENABLE_PROFILING", enable_profiling
+    )
+
     actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
 
     sender = actor_cls.remote()
@@ -204,7 +212,9 @@ def test_torch_tensor_nccl(ray_start_regular):
     dag = dag.with_type_hint(TorchTensorType(transport="nccl"))
     dag = receiver.recv.bind(dag)
 
-    compiled_dag = dag.experimental_compile()
+    compiled_dag = dag.experimental_compile(
+        _overlap_gpu_communication=overlap_gpu_communication
+    )
 
     # Test that we can pass different shapes and data.
     for i in range(3):