diff --git a/.github/workflows/cpu-test.yml b/.github/workflows/cpu-test.yml
new file mode 100644
index 0000000000000..89a702f9751d9
--- /dev/null
+++ b/.github/workflows/cpu-test.yml
@@ -0,0 +1,34 @@
+name: cpu-test
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the habana_main branch
+  push:
+    branches:
+      - habana_main
+  pull_request:
+    branches:
+      - habana_main
+
+
+jobs:
+  cputest:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
+        pip install -r requirements-hpu.txt
+        VLLM_TARGET_DEVICE=hpu python setup.py develop
+    - name: cpu-test
+      run: |
+        VLLM_SKIP_WARMUP=true VLLM_PROMPT_SEQ_BUCKET_MAX=128 VLLM_USE_FAKE_HPU=1 python examples/offline_inference_fakehpu.py
diff --git a/examples/offline_inference_fakehpu.py b/examples/offline_inference_fakehpu.py
new file mode 100644
index 0000000000000..972d84b60b318
--- /dev/null
+++ b/examples/offline_inference_fakehpu.py
@@ -0,0 +1,38 @@
+import os
+
+from vllm import LLM, SamplingParams
+
+if os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0':
+    from vllm.utils import migrate_to_cpu
+    migrate_to_cpu()
+
+# Sample prompts.
+prompts = [
+    "Berlin is the capital city of ",
+    "Louvre is located in the city of ",
+    "Barack Obama was the 44th president of ",
+    "Warsaw is the capital city of ",
+    "Gniezno is a city in ",
+    "San Francisco is located in the state of ",
+    "Llanfairpwllgwyngyll is located in country of ",
+]
+ref_answers = [
+    "Germany", "Paris", "United States", "Poland", "Poland", "California",
+    "Wales"
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0, n=1, use_beam_search=False)
+
+# Create an LLM.
+llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4)
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output, answer in zip(outputs, ref_answers):
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    assert answer in generated_text, (
+        f"The generated text does not contain the correct answer: {answer}")
+print('PASSED')
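The workflow and the example above together form a CPU-only smoke test: with `VLLM_USE_FAKE_HPU=1` every HPU call is mocked out, and greedy decoding (`temperature=0`) keeps the reference-answer assertions deterministic. A sketch of reproducing the CI step locally; the environment variable values are copied from the workflow, and the `runpy` invocation is just one illustrative way to run the script:

```python
# Hypothetical local reproduction of the CI step above. The environment
# must be set before vLLM is imported, since vllm/__init__.py checks
# VLLM_USE_FAKE_HPU at import time.
import os
import runpy

os.environ["VLLM_USE_FAKE_HPU"] = "1"             # mock all HPU calls
os.environ["VLLM_SKIP_WARMUP"] = "true"           # skip HPU graph warmup
os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = "128"  # shrink prompt buckets

runpy.run_path("examples/offline_inference_fakehpu.py", run_name="__main__")
```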
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 0895c571d1d89..29fc02ae3e96a 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -1,4 +1,8 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
+from vllm.utils import is_fake_hpu, migrate_to_cpu
+
+if is_fake_hpu():
+    migrate_to_cpu()
 
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_habana_executor.py
index 17e3414a96b57..2a8e2df37f031 100644
--- a/vllm/executor/ray_habana_executor.py
+++ b/vllm/executor/ray_habana_executor.py
@@ -13,7 +13,7 @@
 from vllm.utils import (_run_task_with_lock,
                         error_on_invalid_device_count_status,
                         get_distributed_init_method, get_ip, get_open_port,
-                        get_vllm_instance_id, make_async)
+                        get_vllm_instance_id, is_fake_hpu, make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -87,18 +87,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         driver_ip = get_ip()
         worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("HPU", 0):
+            resource_name = "HPU" if not is_fake_hpu() else "CPU"
+            if not bundle.get(resource_name, 0):
                 continue
             scheduling_strategy = PlacementGroupSchedulingStrategy(
                 placement_group=placement_group,
                 placement_group_capture_child_tasks=True,
                 placement_group_bundle_index=bundle_id,
             )
-
+            resources = {'HPU': num_gpus} if not is_fake_hpu() else {}
+            num_cpus = 0 if not is_fake_hpu() else num_gpus
             worker = ray.remote(
-                num_cpus=0,
+                num_cpus=num_cpus,
                 num_gpus=0,
-                resources={'HPU': num_gpus},
+                resources=resources,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
             )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 507dc04f48123..8f5bc30a9599c 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -3,7 +3,8 @@
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_ip, is_hip, is_hpu, is_tpu, is_xpu
+from vllm.utils import (get_ip, hpu_device_string, is_hip, is_hpu, is_tpu,
+                        is_xpu)
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -97,7 +98,7 @@ def initialize_ray_cluster(
     if is_tpu():
         device_str = "TPU"
     elif is_hpu():
-        device_str = "HPU"
+        device_str = hpu_device_string()
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
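The executor and ray_utils changes swap the Ray resource key wholesale: under fake HPU the placement group is matched on `CPU` bundles and workers reserve CPUs instead of the `HPU` custom resource. A standalone sketch of that bundle-selection logic (the function name and shapes are illustrative, not part of the patch):

```python
from typing import Any, Dict, List

def select_worker_bundles(bundle_specs: List[Dict[str, Any]],
                          fake_hpu: bool) -> List[int]:
    """Return indices of placement-group bundles that can host a worker."""
    resource_name = "CPU" if fake_hpu else "HPU"
    return [
        bundle_id for bundle_id, bundle in enumerate(bundle_specs)
        if bundle.get(resource_name, 0)
    ]

# On a CPU-only (fake HPU) cluster only the CPU bundle qualifies:
assert select_worker_bundles([{"CPU": 1}, {"HPU": 1}], fake_hpu=True) == [0]
assert select_worker_bundles([{"CPU": 1}, {"HPU": 1}], fake_hpu=False) == [1]
```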
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 06048d97088e1..c49ccc96c7080 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -37,7 +37,7 @@
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_hpu, is_tpu
+from vllm.utils import is_fake_hpu, is_hpu, is_tpu
 
 logger = init_logger(__name__)
 
@@ -277,7 +277,10 @@ def load_model(self, *, model_config: ModelConfig,
                    scheduler_config: SchedulerConfig,
                    cache_config: CacheConfig) -> nn.Module:
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(self.load_config.device):
+            _device = torch.device(
+                device_config.device) if is_fake_hpu() else torch.device(
+                    self.load_config.device)
+            with _device:
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, multimodal_config,
                                           cache_config, scheduler_config)
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index a05090cd46648..3f842ea757d2f 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -254,7 +254,6 @@ def forward(
         if self.project_in is not None:
             inputs_embeds, _ = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
-
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
diff --git a/vllm/utils.py b/vllm/utils.py
index fa6e132dd3522..04782cf13fce5 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -208,10 +208,41 @@ def is_neuron() -> bool:
 
 @lru_cache(maxsize=None)
 def is_hpu() -> bool:
+    return _is_habana_frameworks_installed() or _is_built_for_hpu()
+
+
+@lru_cache(maxsize=None)
+def is_fake_hpu() -> bool:
+    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'
+
+
+@lru_cache(maxsize=None)
+def hpu_device_string():
+    device_string = 'hpu' if not is_fake_hpu() else 'cpu'
+    return device_string
+
+
+@lru_cache(maxsize=None)
+def hpu_backend_string():
+    backend_string = 'hccl' if not is_fake_hpu() else 'gloo'
+    return backend_string
+
+
+@lru_cache(maxsize=None)
+def _is_habana_frameworks_installed() -> bool:
     from importlib import util
     return util.find_spec('habana_frameworks') is not None
 
 
+@lru_cache(maxsize=None)
+def _is_built_for_hpu() -> bool:
+    from importlib.metadata import PackageNotFoundError, version
+    try:
+        return "gaudi" in version("vllm")
+    except PackageNotFoundError:
+        return False
+
+
 @lru_cache(maxsize=None)
 def is_tpu() -> bool:
     try:
@@ -624,18 +655,24 @@ def __init__(self, device=None):
 
     @staticmethod
     def current_device_memory_usage() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory - free_hpu_memory
 
     @staticmethod
     def current_free_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, _ = torch.hpu.mem_get_info()
         return free_hpu_memory
 
     @staticmethod
     def total_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         _, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory
@@ -1088,3 +1125,29 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
+
+
+def migrate_to_cpu():
+    import importlib
+    from unittest.mock import MagicMock
+
+    torch.hpu = MagicMock(name="torch.hpu")
+
+    # Adding dummy submodules to habana_frameworks.torch for cpu-test,
+    # functions from dummy modules will do nothing by default
+    spec = importlib.util.spec_from_loader('habana_frameworks', loader=None)
+    sys.modules['habana_frameworks'] = MagicMock()
+    sys.modules['habana_frameworks'].__spec__ = spec
+
+    builtin_import = __builtins__['__import__']  # type: ignore
+
+    def import_wrapper(name, *args, **kwargs):
+        if 'habana_frameworks' in name:
+            sys.modules[name] = MagicMock()
+        return builtin_import(name, *args, **kwargs)
+
+    __builtins__['__import__'] = import_wrapper
+
+    # In case you want to mock a function to actually do something
+    import habana_frameworks.torch as htorch
+    htorch.utils.internal.is_lazy.return_value = False
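The import hook installed by `migrate_to_cpu()` is what makes the rest of the patch work on machines without the Habana SDK: any `habana_frameworks` import resolves to a `MagicMock`, so attribute accesses and calls silently become no-ops. A self-contained illustration of the same trick (it uses `builtins.__import__` instead of the `__builtins__` dict, but otherwise mirrors the function above):

```python
import builtins
import importlib.util
import sys
from unittest.mock import MagicMock

# Seed the top-level package with a mock carrying a real, loader-less spec,
# so importlib.util.find_spec('habana_frameworks') treats it as installed.
spec = importlib.util.spec_from_loader('habana_frameworks', loader=None)
sys.modules['habana_frameworks'] = MagicMock()
sys.modules['habana_frameworks'].__spec__ = spec

builtin_import = builtins.__import__

def import_wrapper(name, *args, **kwargs):
    # Register a fresh MagicMock for every habana_frameworks submodule.
    if 'habana_frameworks' in name:
        sys.modules[name] = MagicMock()
    return builtin_import(name, *args, **kwargs)

builtins.__import__ = import_wrapper

import habana_frameworks.torch.core as htcore  # a MagicMock, not the SDK

htcore.mark_step()  # returns a MagicMock; nothing happens, nothing raises
```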
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index ec0b8c2369210..f678d44f71dd3 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -6,7 +6,7 @@
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu,
                         is_pin_memory_available)
 
 logger = init_logger(__name__)
@@ -78,7 +78,7 @@ def _allocate_kv_cache(
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_attention_layers):
-            if device == 'hpu':
+            if device == 'hpu' or is_fake_hpu():
                 key_cache = torch.zeros(kv_cache_shape,
                                         dtype=self.dtype,
                                         device=device)
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index ce3848ae0a6da..91f5329554746 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -35,7 +35,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (HabanaMemoryProfiler, format_bytes,
+from vllm.utils import (HabanaMemoryProfiler, format_bytes, is_fake_hpu,
                         is_pin_memory_available, make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
@@ -244,7 +244,8 @@ def __init__(self, model, block_size, dtype, enforce_eager):
             '0').lower() in ['1', 'true']
         self.block_size = block_size
         self.dtype = dtype
-        if not htorch.utils.internal.is_lazy() and not enforce_eager:
+        if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
+        ) and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
                                        dynamic=False)
@@ -507,7 +508,9 @@ def __init__(
                          if model_config is not None else None)
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
-
+        if is_fake_hpu():
+            device_config.device = torch.device('cpu')
+            device_config.device_type = 'cpu'
         self.device = self.device_config.device
         self.enforce_eager = self.model_config.enforce_eager
         self.max_num_seqs = self.scheduler_config.max_num_seqs
@@ -612,7 +615,7 @@ def load_model(self) -> None:
                         mark_only_scales_as_const=True)
                 logger.info("Preparing model with INC took %s",
                             m_inc.get_summary_string())
-            else:
+            elif not is_fake_hpu():
                 self.model = self.model.to("hpu")
                 htcore.mark_step()
                 torch.hpu.synchronize()
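On the runner side, fake-HPU mode rewrites the device configuration to CPU before `self.device` and everything downstream is derived from it. A minimal sketch of that override, with an assumed simplified stand-in for `vllm.config.DeviceConfig`:

```python
import os
from dataclasses import dataclass, field

import torch

def is_fake_hpu() -> bool:
    # Same check as vllm.utils.is_fake_hpu(), without the lru_cache.
    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'

@dataclass
class DeviceConfig:  # hypothetical stand-in for vllm.config.DeviceConfig
    device: torch.device = field(default_factory=lambda: torch.device('hpu'))
    device_type: str = 'hpu'

device_config = DeviceConfig()
if is_fake_hpu():
    # Mirror of the override in HabanaModelRunner.__init__ above.
    device_config.device = torch.device('cpu')
    device_config.device_type = 'cpu'

print(device_config.device)  # cpu when VLLM_USE_FAKE_HPU=1, hpu otherwise
```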
logger.info("Preparing model with INC took %s", m_inc.get_summary_string()) - else: + elif not is_fake_hpu(): self.model = self.model.to("hpu") htcore.mark_step() torch.hpu.synchronize() diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index 9d083915041fe..b4f6e53c1745a 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -21,7 +21,8 @@ from vllm.model_executor import set_random_seed from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import HabanaMemoryProfiler, format_bytes +from vllm.utils import (HabanaMemoryProfiler, format_bytes, hpu_backend_string, + hpu_device_string, is_fake_hpu) from vllm.worker.cache_engine import CacheEngine from vllm.worker.habana_model_runner import HabanaModelRunner from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput @@ -105,6 +106,8 @@ def init_device(self) -> None: if self.device_config.device.type == "hpu": self.device = torch.device("hpu") torch.hpu.set_device(self.device) + elif self.device_config.device_type == "cpu": + self.device = torch.device("cpu") else: raise RuntimeError( f"Not support device type: {self.device_config.device}") @@ -138,6 +141,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. + if is_fake_hpu(): + cache_block_size = self.get_cache_block_size_bytes() + fake_hpu_cache_alloc = 4 * 2**30 # take 4 GiB flat on fake hpu + return fake_hpu_cache_alloc // cache_block_size, 0 with HabanaMemoryProfiler() as m: self.model_runner.profile_run() torch.hpu.synchronize() @@ -335,11 +342,12 @@ def init_worker_distributed_environment( local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + backend = hpu_backend_string() init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, - backend='hccl') + backend=backend) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) @@ -356,15 +364,17 @@ def init_worker_distributed_environment( "distributed_init_method must be set if torch.distributed " "is not already initialized") else: + backend = hpu_backend_string() torch.distributed.init_process_group( - backend="hccl", + backend=backend, world_size=parallel_config.world_size, rank=rank, init_method=distributed_init_method, ) # A small all_reduce for warmup & checking conformance. - dummy_tensor_hpu = torch.ones(1).to('hpu') + device = hpu_device_string() + dummy_tensor_hpu = torch.ones(1).to(device) torch.distributed.all_reduce(dummy_tensor_hpu) assert dummy_tensor_hpu.item() == parallel_config.world_size ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,