diff --git a/.github/workflows/cpu-test.yml b/.github/workflows/cpu-test.yml
new file mode 100644
index 0000000000000..53638d30980d8
--- /dev/null
+++ b/.github/workflows/cpu-test.yml
@@ -0,0 +1,34 @@
+name: cpu-test
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the habana_main branch
+  push:
+    branches:
+      - habana_main
+  pull_request:
+    branches:
+      - habana_main
+
+
+jobs:
+  cputest:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
+        pip install -r requirements-hpu.txt
+        VLLM_TARGET_DEVICE=hpu python setup.py develop
+    - name: cpu-test
+      run: |
+        VLLM_SKIP_WARMUP=true VLLM_PROMPT_SEQ_BUCKET_MAX=128 python examples/offline_inference_fakehpu.py
diff --git a/examples/offline_inference_fakehpu.py b/examples/offline_inference_fakehpu.py
new file mode 100644
index 0000000000000..e1b2d611a7a8d
--- /dev/null
+++ b/examples/offline_inference_fakehpu.py
@@ -0,0 +1,33 @@
+from vllm import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Berlin is the capital city of ",
+    "Louvre is located in the city called ",
+    "Barack Obama was the 44th president of ",
+    "Warsaw is the capital city of ",
+    "Gniezno is a city in ",
+    "Hebrew is an official state language of ",
+    "San Francisco is located in the state of ",
+    "Llanfairpwllgwyngyll is located in country of ",
+]
+ref_answers = [
+    "Germany", "Paris", "United States", "Poland", "Poland", "Israel",
+    "California", "Wales"
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0, n=1, use_beam_search=False)
+
+# Create an LLM.
+llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4)
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output, answer in zip(outputs, ref_answers):
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    assert answer in generated_text, (
+        f"The generated text does not contain the correct answer: {answer}")
+print('PASSED')
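The CI job above exercises this script through the fallback path (a Gaudi-targeted build on a CPU-only runner, so habana_frameworks is absent and the fake HPU kicks in on its own). To drive the same check by hand, a minimal sketch, assuming this fork of vLLM is installed and using the VLLM_USE_FAKE_HPU and VLLM_SKIP_WARMUP switches referenced elsewhere in this diff:

# Force the fake-HPU (CPU) path before vLLM is imported; is_fake_hpu() in
# vllm/utils.py reads VLLM_USE_FAKE_HPU once and caches the result.
import os

os.environ.setdefault("VLLM_USE_FAKE_HPU", "1")
os.environ.setdefault("VLLM_SKIP_WARMUP", "true")

from vllm import LLM, SamplingParams  # imported after the env vars on purpose

llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4)
outputs = llm.generate(["Berlin is the capital city of "],
                       SamplingParams(temperature=0, n=1))
print(outputs[0].outputs[0].text)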
diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py
index cc9b19ce022b5..e68279ffc42d9 100644
--- a/vllm/distributed/device_communicators/hpu_communicator.py
+++ b/vllm/distributed/device_communicators/hpu_communicator.py
@@ -3,8 +3,9 @@
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
+from vllm.utils import is_fake_hpu
 
-if current_platform.is_hpu():
+if current_platform.is_hpu() and not is_fake_hpu():
     import habana_frameworks.torch as htorch  # noqa: F401
 
 
@@ -22,7 +23,8 @@ def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
         # (which is required for tensor parallel HPUGraph inference)
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         dist.all_reduce(x, group=self.group)
         return x
 
@@ -37,7 +39,8 @@ def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
                                     dtype=x.dtype,
                                     device=x.device)
         # All-gather.
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         dist.all_gather_into_tensor(output_tensor, x, group=self.group)
         # Reshape
         output_tensor = output_tensor.movedim(0, dim)
diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_habana_executor.py
index 9e0a89cbeb8aa..c45513e3e5c91 100644
--- a/vllm/executor/ray_habana_executor.py
+++ b/vllm/executor/ray_habana_executor.py
@@ -13,7 +13,7 @@
 from vllm.utils import (_run_task_with_lock,
                         error_on_invalid_device_count_status,
                         get_distributed_init_method, get_ip, get_open_port,
-                        get_vllm_instance_id, make_async)
+                        get_vllm_instance_id, is_fake_hpu, make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -87,18 +87,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         driver_ip = get_ip()
         worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("HPU", 0):
+            resource_name = "HPU" if not is_fake_hpu() else "CPU"
+            if not bundle.get(resource_name, 0):
                 continue
             scheduling_strategy = PlacementGroupSchedulingStrategy(
                 placement_group=placement_group,
                 placement_group_capture_child_tasks=True,
                 placement_group_bundle_index=bundle_id,
             )
-
+            resources = {'HPU': num_gpus} if not is_fake_hpu() else {}
+            num_cpus = 0 if not is_fake_hpu() else num_gpus
             worker = ray.remote(
-                num_cpus=0,
+                num_cpus=num_cpus,
                 num_gpus=0,
-                resources={'HPU': num_gpus},
+                resources=resources,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
             )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 507dc04f48123..8259e2fc49a84 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -3,7 +3,7 @@
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_ip, is_hip, is_hpu, is_tpu, is_xpu
+from vllm.utils import get_ip, is_fake_hpu, is_hip, is_hpu, is_tpu, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -97,7 +97,7 @@ def initialize_ray_cluster(
     if is_tpu():
         device_str = "TPU"
     elif is_hpu():
-        device_str = "HPU"
+        device_str = "HPU" if not is_fake_hpu() else 'CPU'
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
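The Ray changes above all follow one pattern: wherever the executor used to ask Ray for the custom "HPU" resource, it asks for plain CPUs when the HPU is faked. A self-contained sketch of that selection logic, with a fake_hpu flag standing in for vllm.utils.is_fake_hpu() and a plain dictionary standing in for the ray.remote() keyword arguments:

# Sketch of the resource-selection pattern used in ray_habana_executor.py.
# `fake_hpu` stands in for vllm.utils.is_fake_hpu(); `num_gpus` keeps the
# name used by the executor even though it counts HPU/CPU slots here.
from typing import Any, Dict


def worker_resource_kwargs(fake_hpu: bool, num_gpus: int) -> Dict[str, Any]:
    """Return the ray.remote() resource kwargs for one worker."""
    resources = {'HPU': num_gpus} if not fake_hpu else {}
    num_cpus = 0 if not fake_hpu else num_gpus
    return dict(num_cpus=num_cpus, num_gpus=0, resources=resources)


# Real HPU: a custom 'HPU' resource per worker, no CPU reservation.
assert worker_resource_kwargs(False, 1) == {
    'num_cpus': 0, 'num_gpus': 0, 'resources': {'HPU': 1}}
# Fake HPU: the same worker count is expressed as ordinary CPUs.
assert worker_resource_kwargs(True, 1) == {
    'num_cpus': 1, 'num_gpus': 0, 'resources': {}}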
diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py
index 14824945aa53a..a69105e18c3bd 100644
--- a/vllm/hpu/cache_ops.py
+++ b/vllm/hpu/cache_ops.py
@@ -5,7 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 ###############################################################################
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
 import torch
 
 
diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index c8f00c1cbd59d..f2ea8202e0487 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -7,7 +7,11 @@
 import os
 from typing import Optional
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
+
 import torch
 import torch.nn.functional as F
 
diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py
index b7b435c50c295..0d7e92351714a 100644
--- a/vllm/hpu/utils.py
+++ b/vllm/hpu/utils.py
@@ -7,18 +7,23 @@
 
 from functools import wraps
 
-import habana_frameworks.torch as htorch
+from vllm.utils import is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
 
 
 def with_mark_steps(fn):
 
     @wraps(fn)
     def wrapped(*args, **kwargs):
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         result = fn(*args, **kwargs)
         del args
         del kwargs
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         return result
 
     return wrapped
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index a05090cd46648..aa65bb2625fc0 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -100,6 +100,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
+        # import pdb; pdb.set_trace()
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
@@ -254,7 +255,6 @@ def forward(
         if self.project_in is not None:
             inputs_embeds, _ = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
-
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
diff --git a/vllm/utils.py b/vllm/utils.py
index 8a1bc5de03eb7..21f1b39d4c3dd 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -207,10 +207,30 @@ def is_neuron() -> bool:
 
 @lru_cache(maxsize=None)
 def is_hpu() -> bool:
+    return _is_habana_frameworks_installed() or _is_built_for_hpu()
+
+
+@lru_cache(maxsize=None)
+def is_fake_hpu() -> bool:
+    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0' or (
+        not _is_habana_frameworks_installed() and _is_built_for_hpu())
+
+
+@lru_cache(maxsize=None)
+def _is_habana_frameworks_installed() -> bool:
     from importlib import util
     return util.find_spec('habana_frameworks') is not None
 
 
+@lru_cache(maxsize=None)
+def _is_built_for_hpu() -> bool:
+    from importlib.metadata import PackageNotFoundError, version
+    try:
+        return "gaudi" in version("vllm")
+    except PackageNotFoundError:
+        return False
+
+
 @lru_cache(maxsize=None)
 def is_tpu() -> bool:
     try:
@@ -623,18 +643,24 @@ def __init__(self, device=None):
 
     @staticmethod
     def current_device_memory_usage() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory - free_hpu_memory
 
     @staticmethod
     def current_free_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, _ = torch.hpu.mem_get_info()
         return free_hpu_memory
 
     @staticmethod
     def total_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         _, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory
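The helpers added to vllm/utils.py split HPU detection into two probes: whether habana_frameworks is importable, and whether the installed vllm wheel was built for Gaudi ("gaudi" in its version string). A small sketch of the resulting truth table, with the two probes passed in as booleans so it runs anywhere; the bodies mirror the helpers above but the boolean parameters are an illustrative stand-in for the real probes:

# Sketch of the detection precedence added in vllm/utils.py.
import os

os.environ.pop('VLLM_USE_FAKE_HPU', None)  # make the sketch deterministic


def is_hpu(frameworks_installed: bool, built_for_hpu: bool) -> bool:
    return frameworks_installed or built_for_hpu


def is_fake_hpu(frameworks_installed: bool, built_for_hpu: bool) -> bool:
    forced = os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'
    return forced or (not frameworks_installed and built_for_hpu)


# Gaudi wheel without habana_frameworks (the CI case): reported as an HPU,
# but the fake, CPU-backed flavour.
assert is_hpu(False, True) and is_fake_hpu(False, True)
# Real Gaudi machine: HPU reported, not fake.
assert is_hpu(True, True) and not is_fake_hpu(True, True)
# Plain CPU build: neither, unless VLLM_USE_FAKE_HPU forces the fake path.
assert not is_hpu(False, False) and not is_fake_hpu(False, False)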
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 93be2f4c321fe..950b896c3b1b6 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -6,7 +6,7 @@
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu,
                         is_pin_memory_available)
 
 logger = init_logger(__name__)
@@ -78,7 +78,7 @@ def _allocate_kv_cache(
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_attention_layers):
-            if device == 'hpu':
+            if device == 'hpu' or is_fake_hpu():
                 key_cache = torch.zeros(kv_cache_shape,
                                         dtype=self.dtype,
                                         device=device)
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index cf91c69069ed6..0527310ff32c9 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -14,7 +14,12 @@
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                     Optional, Set, Tuple, Type, TypeVar, Union)
 
-import habana_frameworks.torch as htorch
+from vllm.utils import (HabanaMemoryProfiler, format_bytes, is_fake_hpu,
+                        is_pin_memory_available, make_tensor_with_pad)
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch
+
 import torch
 
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -31,8 +36,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (HabanaMemoryProfiler, format_bytes,
-                        is_pin_memory_available, make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
@@ -151,7 +154,8 @@ class HpuModelAdapter():
 
     def __init__(self, model, enforce_eager):
         self.model = model
-        if not htorch.utils.internal.is_lazy() and not enforce_eager:
+        if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
+        ) and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
                                        dynamic=False)
@@ -380,7 +384,9 @@
                          if model_config is not None else None)
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
-
+        if is_fake_hpu():
+            device_config.device = torch.device('cpu')
+            device_config.device_type = 'cpu'
         self.device = self.device_config.device
         self.enforce_eager = self.model_config.enforce_eager
         self.max_num_seqs = self.scheduler_config.max_num_seqs
@@ -1048,11 +1054,13 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt,
             self.create_dummy_seq_group_metadata(i, seq_len, is_prompt)
             for i in range(batch_size)
         ]
-        torch.hpu.synchronize()
+        if not is_fake_hpu():
+            torch.hpu.synchronize()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
             self.execute_model(inputs, kv_caches)
-            torch.hpu.synchronize()
+            if not is_fake_hpu():
+                torch.hpu.synchronize()
         self.profiler.end()
         gc.collect()
 
@@ -1138,7 +1146,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
         self.warmup_all_buckets(self.decode_buckets, False, kv_caches)
 
-        if not self.enforce_eager and htorch.utils.internal.is_lazy():
+        if not is_fake_hpu(
+        ) and not self.enforce_eager and htorch.utils.internal.is_lazy():
             assert self.mem_margin is not None, \
                 ("HabanaWorker.determine_num_available_blocks needs "
                  "to be called before warming up the model.")
@@ -1220,6 +1229,8 @@ def mem_margin(self, value):
 
 
 def _maybe_wrap_in_hpu_graph(*args, **kwargs):
+    if is_fake_hpu():
+        return HpuModelAdapter(*args, **kwargs)
     return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(
         *args, **
         kwargs)) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
@@ -1403,7 +1414,8 @@
         if multi_modal_input is not None:
             execute_model_kwargs.update(multi_modal_input)
 
-        htorch.core.mark_step()
+        if not is_fake_hpu():
+            htorch.core.mark_step()
         if self.is_driver_worker:
             model_event_name = ("model_"
                                 f"{'prompt' if is_prompt else 'decode'}_"
@@ -1428,7 +1440,8 @@
                 sampling_metadata.selected_token_indices = None
             logits = self.model.compute_logits(hidden_states,
                                                sampling_metadata)
-            htorch.core.mark_step()
+            if not is_fake_hpu():
+                htorch.core.mark_step()
             # Only perform sampling in the driver worker.
             if not self.is_driver_worker:
                 return []
@@ -1444,7 +1457,8 @@
                 sampling_metadata=sampling_metadata,
             )
             output.outputs = output.outputs[:real_batch_size]
-            htorch.core.mark_step()
+            if not is_fake_hpu():
+                htorch.core.mark_step()
 
         if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
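Most of the model-runner edits repeat the same "if not is_fake_hpu(): htorch.core.mark_step()" guard. A hypothetical way to centralize that guard is a shim selected at import time; this is only a sketch of the pattern under the assumption that this fork is installed, not something the diff itself introduces:

# Hypothetical consolidation of the repeated `if not is_fake_hpu():` guard
# seen throughout habana_model_runner.py; only is_fake_hpu() comes from this
# diff, the mark_step() shim itself is illustrative.
from vllm.utils import is_fake_hpu

if not is_fake_hpu():
    import habana_frameworks.torch as htorch

    def mark_step() -> None:
        htorch.core.mark_step()
else:

    def mark_step() -> None:
        # No-op on the CPU-backed fake HPU.
        pass

Call sites could then invoke mark_step() unconditionally instead of repeating the guard at every htorch boundary.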
diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py
index f3fdc4dcc63c6..5e3b48dc70356 100644
--- a/vllm/worker/habana_worker.py
+++ b/vllm/worker/habana_worker.py
@@ -6,7 +6,11 @@
 import os
 from typing import List, Optional, Set, Tuple
 
-import habana_frameworks.torch as htorch  # noqa:F401
+from vllm.utils import HabanaMemoryProfiler, format_bytes, is_fake_hpu
+
+if not is_fake_hpu():
+    import habana_frameworks.torch as htorch  # noqa:F401
+
 import torch
 import torch.distributed
 
@@ -21,7 +25,6 @@
 from vllm.model_executor import set_random_seed
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import HabanaMemoryProfiler, format_bytes
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.habana_model_runner import HabanaModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
@@ -95,6 +98,8 @@ def init_device(self) -> None:
         if self.device_config.device.type == "hpu":
             self.device = torch.device("hpu")
             torch.hpu.set_device(self.device)
+        elif self.device_config.device_type == "cpu":
+            self.device = torch.device("cpu")
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -126,6 +131,11 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
+        if is_fake_hpu():
+            # self.model_runner.profile_run()
+            cache_block_size = self.get_cache_block_size_bytes()
+            fake_hpu_cache_alloc = 4 * 2**30  # take 4 GiB flat on fake hpu
+            return fake_hpu_cache_alloc // cache_block_size, 0
         with HabanaMemoryProfiler() as m:
             self.model_runner.profile_run()
             torch.hpu.synchronize()
@@ -184,7 +194,8 @@ def initialize_cache(self, num_gpu_blocks: int,
 
         with HabanaMemoryProfiler() as m:
             self._init_cache_engine()
-            torch.hpu.synchronize()
+            if not is_fake_hpu():
+                torch.hpu.synchronize()
         msg = ("Initializing cache engine "
                f"took {m.get_summary_string()}")
         logger.info(msg)
@@ -311,11 +322,12 @@ def init_worker_distributed_environment(
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    backend = 'hccl' if not is_fake_hpu() else 'gloo'
     init_distributed_environment(parallel_config.world_size,
                                  rank,
                                  distributed_init_method,
                                  local_rank,
-                                 backend='hccl')
+                                 backend=backend)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
@@ -332,15 +344,17 @@ def init_worker_distributed_environment(
             "distributed_init_method must be set if torch.distributed "
             "is not already initialized")
     else:
+        backend = 'hccl' if not is_fake_hpu() else 'gloo'
         torch.distributed.init_process_group(
-            backend="hccl",
+            backend=backend,
             world_size=parallel_config.world_size,
             rank=rank,
             init_method=distributed_init_method,
         )
 
     # A small all_reduce for warmup & checking conformance.
-    dummy_tensor_hpu = torch.ones(1).to('hpu')
+    device = 'hpu' if not is_fake_hpu() else 'cpu'
+    dummy_tensor_hpu = torch.ones(1).to(device)
     torch.distributed.all_reduce(dummy_tensor_hpu)
     assert dummy_tensor_hpu.item() == parallel_config.world_size
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,