enabling multi-node vllm on ray cluster, tested for meta-llama/Meta-Llama-3.1-405B-Instruct. Originally implemented in SW-195705
vishnumadhu365 committed Aug 29, 2024
1 parent 17cd625 commit fa819df
Showing 2 changed files with 21 additions and 1 deletion.
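For context, the sketch below (not part of this commit) shows how a multi-node run of this kind could be driven from the head node of an existing Ray cluster. The model name comes from the commit message; tensor_parallel_size=16 and the distributed_executor_backend argument are illustrative assumptions for a two-node, eight-HPU-per-node setup.

# Illustrative only. Assumes the Ray cluster is already up:
#   ray start --head                    (on the head node)
#   ray start --address=<head-ip>:6379  (on every other Gaudi node)
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-405B-Instruct",
    tensor_parallel_size=16,             # spans two 8-HPU nodes
    distributed_executor_backend="ray",  # use the Ray-based executor
)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)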
vllm/executor/ray_habana_executor.py (17 changes: 16 additions & 1 deletion)
@@ -29,6 +29,8 @@
class RayHabanaExecutor(DistributedGPUExecutor):

    uses_ray: bool = True



    def _init_executor(self) -> None:
        self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
@@ -76,6 +78,12 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]:
    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        num_gpus = 1

        def retain_envs(var_name):
            retain_var_list = ['RAY_DEDUP_LOGS', 'GLOO_SOCKET_IFNAME', 'VLLM_SKIP_WARMUP']
            return ('HPU' in var_name or var_name in retain_var_list)



        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
@@ -94,12 +102,15 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            runtime_env_vars = {k:v for k, v in os.environ.items() if retain_envs(k)}

            worker = ray.remote(
                num_cpus=0,
                num_gpus=0,
                resources={'HPU': num_gpus},
                scheduling_strategy=scheduling_strategy,
                runtime_env={"env_vars": runtime_env_vars},
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

@@ -115,7 +126,11 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                        **worker_wrapper_kwargs)
                else:
                    # Else, added to the list of workers.
                    self.workers.append(worker)
                    #self.workers.append(worker)
                    if worker_ip == driver_ip:
                        self.workers.insert(0, worker)
                    else:
                        self.workers.append(worker)

        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
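Taken together, the executor changes do two things: they forward a filtered slice of the driver's environment (any variable containing 'HPU' plus the explicit retain list) to every Ray worker via Ray's runtime_env, and they insert workers that share the driver's node at the front of self.workers so driver-colocated workers are ranked first. The snippet below is a stand-alone sketch (not part of the commit) of the env-var forwarding, using Ray's runtime_env={"env_vars": ...} mechanism with the same filter as the diff.

import os
import ray

# Allow-list copied from the diff; anything containing 'HPU' is also kept.
RETAIN_VAR_LIST = ['RAY_DEDUP_LOGS', 'GLOO_SOCKET_IFNAME', 'VLLM_SKIP_WARMUP']

def retain_envs(var_name: str) -> bool:
    return 'HPU' in var_name or var_name in RETAIN_VAR_LIST

# Only the filtered variables are pushed to workers, not the whole driver env.
runtime_env_vars = {k: v for k, v in os.environ.items() if retain_envs(k)}

@ray.remote(num_cpus=0, runtime_env={"env_vars": runtime_env_vars})
def report_env() -> dict:
    # Runs on a (possibly remote) Ray worker; the forwarded variables show up here.
    return {k: v for k, v in os.environ.items() if retain_envs(k)}

if __name__ == "__main__":
    ray.init()  # or ray.init(address="auto") when attaching to a running cluster
    print(ray.get(report_env.remote()))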
vllm/worker/habana_worker.py (5 changes: 5 additions & 0 deletions)
@@ -90,6 +90,11 @@ def __init__(
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.hpu_cache: Optional[List[List[torch.tensor]]] = None
        if self.parallel_config.world_size > 8:
            from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
            if 'HABANA_VISIBLE_MODULES' in os.environ:
                os.environ.pop('HABANA_VISIBLE_MODULES')
            initialize_distributed_hpu(self.parallel_config.world_size, self.rank, self.local_rank)

    def _set_env_vars(self):
        local_rank = self.local_rank
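The worker-side change only triggers when the tensor-parallel world size exceeds the eight HPUs of a single Gaudi node: the per-node HABANA_VISIBLE_MODULES mask is dropped and HCCL is initialized explicitly with the global world size, rank, and local rank. Below is a stand-alone sketch of that guard; the helper name is illustrative, while initialize_distributed_hpu and its positional arguments mirror the diff.

import os

def maybe_init_multinode_hccl(world_size: int, rank: int, local_rank: int) -> None:
    # Illustrative helper mirroring the world_size > 8 branch above.
    if world_size <= 8:
        return  # single node: the default per-node initialization is enough
    # Deferred import: habana_frameworks is only available on Gaudi hosts.
    from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
    # A leftover per-node device mask would conflict with a multi-node group.
    os.environ.pop('HABANA_VISIBLE_MODULES', None)
    initialize_distributed_hpu(world_size, rank, local_rank)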
