tp fixes
kzawora-intel committed Aug 13, 2024
1 parent ceca996 commit 1976d75
Showing 3 changed files with 16 additions and 10 deletions.
10 changes: 7 additions & 3 deletions vllm/distributed/device_communicators/hpu_communicator.py
@@ -3,9 +3,11 @@
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
+from vllm.utils import is_fake_hpu
 
 if current_platform.is_hpu():
-    import habana_frameworks.torch as htorch  # noqa: F401
+    if not is_fake_hpu():
+        import habana_frameworks.torch as htorch  # noqa: F401
 
 
 class HpuCommunicator:
@@ -22,7 +24,8 @@ def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
         # (which is required for tensor parallel HPUGraph inference)
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         dist.all_reduce(x, group=self.group)
         return x

@@ -37,7 +40,8 @@ def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
                                     dtype=x.dtype,
                                     device=x.device)
         # All-gather.
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         dist.all_gather_into_tensor(output_tensor, x, group=self.group)
         # Reshape
         output_tensor = output_tensor.movedim(0, dim)
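For context, the guard pattern added above can be illustrated in isolation. The sketch below is not the vLLM code itself: `is_fake_hpu` here is a hypothetical stand-in (the real helper lives in `vllm.utils`, and the environment-variable name used below is an assumption), but it shows why `mark_step()` and the `habana_frameworks` import must be skipped when no Habana device is present.

```python
import os

import torch
import torch.distributed as dist


def is_fake_hpu() -> bool:
    # Hypothetical helper for illustration only; the real check lives in
    # vllm.utils and the env var name here is an assumption.
    return os.environ.get("VLLM_USE_FAKE_HPU", "0") != "0"


def all_reduce_with_optional_mark_step(x: torch.Tensor,
                                       group: dist.ProcessGroup) -> torch.Tensor:
    if not is_fake_hpu():
        # Only a real HPU run has habana_frameworks installed; mark_step()
        # flushes pending lazy-mode ops before the collective runs.
        import habana_frameworks.torch as htorch
        htorch.core.mark_step()
    # A fake-HPU (CPU-only) run performs the same collective through a CPU
    # backend such as gloo, with no Habana-specific synchronization.
    dist.all_reduce(x, group=group)
    return x
```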
12 changes: 7 additions & 5 deletions vllm/executor/ray_habana_executor.py
@@ -13,7 +13,7 @@
 from vllm.utils import (_run_task_with_lock,
                         error_on_invalid_device_count_status,
                         get_distributed_init_method, get_ip, get_open_port,
-                        get_vllm_instance_id, make_async)
+                        get_vllm_instance_id, is_fake_hpu, make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -87,18 +87,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         driver_ip = get_ip()
         worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("HPU", 0):
+            resource_name = "HPU" if not is_fake_hpu() else "CPU"
+            if not bundle.get(resource_name, 0):
                 continue
             scheduling_strategy = PlacementGroupSchedulingStrategy(
                 placement_group=placement_group,
                 placement_group_capture_child_tasks=True,
                 placement_group_bundle_index=bundle_id,
             )
 
+            resources = {'HPU': num_gpus} if not is_fake_hpu() else {}
+            num_cpus = 0 if not is_fake_hpu() else num_gpus
             worker = ray.remote(
-                num_cpus=0,
+                num_cpus=num_cpus,
                 num_gpus=0,
-                resources={'HPU': num_gpus},
+                resources=resources,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
             )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
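The worker-creation change boils down to choosing different `ray.remote` resource arguments per mode: real HPU workers reserve custom `'HPU'` resources and no CPUs, while fake-HPU workers reserve plain CPUs. A minimal sketch of that decision, with `is_fake_hpu` again assumed to come from `vllm.utils` and the helper name being illustrative, not part of the commit:

```python
from typing import Any, Dict

from vllm.utils import is_fake_hpu


def worker_resource_kwargs(num_gpus: int) -> Dict[str, Any]:
    """Pick ray.remote() resource kwargs for one worker bundle (sketch)."""
    if not is_fake_hpu():
        # Real HPU mode: reserve custom 'HPU' resources, no CPUs or GPUs.
        return {"num_cpus": 0, "num_gpus": 0, "resources": {"HPU": num_gpus}}
    # Fake-HPU (CPU-only) mode: reserve ordinary CPUs instead.
    return {"num_cpus": num_gpus, "num_gpus": 0, "resources": {}}


# Usage sketch, mirroring the loop in the diff above:
# worker = ray.remote(
#     **worker_resource_kwargs(num_gpus),
#     scheduling_strategy=scheduling_strategy,
#     **ray_remote_kwargs,
# )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
```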
4 changes: 2 additions & 2 deletions vllm/executor/ray_utils.py
@@ -3,7 +3,7 @@
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_ip, is_hip, is_hpu, is_tpu, is_xpu
+from vllm.utils import get_ip, is_fake_hpu, is_hip, is_hpu, is_tpu, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -97,7 +97,7 @@ def initialize_ray_cluster(
     if is_tpu():
         device_str = "TPU"
     elif is_hpu():
-        device_str = "HPU"
+        device_str = "HPU" if not is_fake_hpu() else 'CPU'
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
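Downstream, `device_str` feeds the placement-group bundles that Ray schedules workers onto, so flipping it to `'CPU'` in fake-HPU mode leaves the rest of the cluster setup unchanged. A rough sketch of that step, heavily simplified from `initialize_ray_cluster` (bundle shape and fallback branches are approximations, not the exact upstream code):

```python
import ray

from vllm.utils import is_fake_hpu, is_hpu, is_tpu


def create_placement_group(world_size: int):
    """Simplified sketch: one bundle per worker, keyed by device type."""
    if is_tpu():
        device_str = "TPU"
    elif is_hpu():
        # Fake-HPU runs schedule workers onto plain CPU resources.
        device_str = "HPU" if not is_fake_hpu() else "CPU"
    else:
        device_str = "GPU"
    placement_group_specs = [{device_str: 1}] * world_size
    return ray.util.placement_group(placement_group_specs)
```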
