Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add fake HPU mode to Habana components. #228

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/cpu-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: cpu-test

on:
# Trigger the workflow on push or pull request,
# but only for the habana_main branch
push:
branches:
- habana_main
pull_request:
branches:
- habana_main


jobs:
cputest:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
pip install -r requirements-hpu.txt
VLLM_TARGET_DEVICE=hpu python setup.py develop
- name: cpu-test
run: |
VLLM_SKIP_WARMUP=true VLLM_PROMPT_SEQ_BUCKET_MAX=128 VLLM_USE_FAKE_HPU=1 python examples/offline_inference_fakehpu.py --fake_hpu
41 changes: 41 additions & 0 deletions examples/offline_inference_fakehpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from os import environ

if environ.get('VLLM_USE_FAKE_HPU', '0') == '1':
from vllm.utils import migrate_to_cpu
print("CHECK1")
migrate_to_cpu()
print("CHECK2")

# Sample prompts.
prompts = [
"Berlin is the capital city of ",
"Louvre is located in the city called ",
"Barack Obama was the 44th president of ",
"Warsaw is the capital city of ",
"Gniezno is a city in ",
"Hebrew is an official state language of ",
"San Francisco is located in the state of ",
"Llanfairpwllgwyngyll is located in country of ",
]
ref_answers = [
"Germany", "Paris", "United States", "Poland", "Poland", "Israel",
"California", "Wales"
]

from vllm import LLM, SamplingParams
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0, n=1, use_beam_search=False)

# Create an LLM.
llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output, answer in zip(outputs, ref_answers):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#assert answer in generated_text, (
# f"The generated text does not contain the correct answer: {answer}")
print('PASSED')
5 changes: 5 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

from vllm.utils import is_fake_hpu, migrate_to_cpu

if is_fake_hpu():
migrate_to_cpu()

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
Expand Down
3 changes: 1 addition & 2 deletions vllm/distributed/device_communicators/hpu_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from vllm.platforms import current_platform

if current_platform.is_hpu():
import habana_frameworks.torch as htorch # noqa: F401

import habana_frameworks.torch as htorch # noqa: F401)

class HpuCommunicator:

Expand Down
12 changes: 7 additions & 5 deletions vllm/executor/ray_habana_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from vllm.utils import (_run_task_with_lock,
error_on_invalid_device_count_status,
get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
get_vllm_instance_id, is_fake_hpu, make_async)

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
Expand Down Expand Up @@ -87,18 +87,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("HPU", 0):
resource_name = "HPU" if not is_fake_hpu() else "CPU"
if not bundle.get(resource_name, 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)

resources = {'HPU': num_gpus} if not is_fake_hpu() else {}
num_cpus = 0 if not is_fake_hpu() else num_gpus
worker = ray.remote(
num_cpus=0,
num_cpus=num_cpus,
num_gpus=0,
resources={'HPU': num_gpus},
resources=resources,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_ip, is_hip, is_hpu, is_tpu, is_xpu
from vllm.utils import get_ip, is_fake_hpu, is_hip, is_hpu, is_tpu, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
Expand Down Expand Up @@ -97,7 +97,7 @@ def initialize_ray_cluster(
if is_tpu():
device_str = "TPU"
elif is_hpu():
device_str = "HPU"
device_str = "HPU" if not is_fake_hpu() else 'CPU'
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
Expand Down
3 changes: 2 additions & 1 deletion vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional

import habana_frameworks.torch as htorch

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -265,4 +266,4 @@ def dispatch_bgmv_embedding(
x = x.unsqueeze(1)
out = x @ wa
out = out.squeeze(1)
y += out * scale
y += out * scale
1 change: 1 addition & 0 deletions vllm/hpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from functools import wraps

import habana_frameworks.torch as htorch

import torch

from vllm.hpu.cache_ops import insert_or_update_cache
Expand Down
2 changes: 0 additions & 2 deletions vllm/model_executor/models/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,9 @@ def forward(
if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds

for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
Expand Down
68 changes: 68 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,30 @@ def is_neuron() -> bool:

@lru_cache(maxsize=None)
def is_hpu() -> bool:
return _is_habana_frameworks_installed() or _is_built_for_hpu()


@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0' or (
not _is_habana_frameworks_installed() and _is_built_for_hpu())


@lru_cache(maxsize=None)
def _is_habana_frameworks_installed() -> bool:
from importlib import util
if os.environ.get('VLLM_USE_FAKE_HPU', '0') == '1':
return False
return util.find_spec('habana_frameworks') is not None

@lru_cache(maxsize=None)
def _is_built_for_hpu() -> bool:
from importlib.metadata import PackageNotFoundError, version
try:
return "gaudi" in version("vllm")
except PackageNotFoundError:
return False


@lru_cache(maxsize=None)
def is_tpu() -> bool:
Expand Down Expand Up @@ -624,18 +645,24 @@ def __init__(self, device=None):

@staticmethod
def current_device_memory_usage() -> float:
if is_fake_hpu():
return 0
# Return the device memory usage in bytes.
free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
return total_hpu_memory - free_hpu_memory

@staticmethod
def current_free_device_memory() -> float:
if is_fake_hpu():
return 0
# Return the device memory usage in bytes.
free_hpu_memory, _ = torch.hpu.mem_get_info()
return free_hpu_memory

@staticmethod
def total_device_memory() -> float:
if is_fake_hpu():
return 0
# Return the device memory usage in bytes.
_, total_hpu_memory = torch.hpu.mem_get_info()
return total_hpu_memory
Expand Down Expand Up @@ -1088,3 +1115,44 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)

def _create_dummy_modules():
import types
import importlib

habana_frameworks = types.ModuleType('habana_frameworks')
spec = importlib.util.spec_from_loader('habana_frameworks', loader=None)
habana_frameworks.__spec__ = spec
sys.modules['habana_frameworks'] = habana_frameworks
sys.modules['habana_frameworks.torch'] = habana_frameworks.torch = types.ModuleType('habana_frameworks.torch')
sys.modules['habana_frameworks.torch.core'] = habana_frameworks.torch.core = types.ModuleType('habana_frameworks.torch.core')

sys.modules['habana_frameworks.torch.utils'] = habana_frameworks.torch.utils = types.ModuleType('habana_frameworks.torch.utils')
sys.modules['habana_frameworks.torch.utils.internal'] = habana_frameworks.torch.utils.internal = types.ModuleType('habana_frameworks.torch.utils.internal')

sys.modules['torch.hpu'] = torch.hpu = types.ModuleType('torch.hpu')

habana_frameworks.torch.core.mark_step = lambda: print('calling mark_step')
habana_frameworks.torch.utils.internal.is_lazy = lambda: print('calling is_lazy')
torch.hpu.synchronize = lambda: print('calling synchronize')

def _do_nothing():
pass

def _return_false():
return False

def _migrate_to_cpu():
import habana_frameworks.torch as htorch

print(sys.modules['torch'].__spec__)
print(sys.modules['habana_frameworks'].__spec__)

htorch.core.mark_step = _do_nothing
htorch.utils.internal.is_lazy = _return_false
torch.hpu.synchronize = _do_nothing

def migrate_to_cpu():
_create_dummy_modules()
_migrate_to_cpu()

4 changes: 2 additions & 2 deletions vllm/worker/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu,
is_pin_memory_available)

logger = init_logger(__name__)
Expand Down Expand Up @@ -78,7 +78,7 @@ def _allocate_kv_cache(
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_attention_layers):
if device == 'hpu':
if device == 'hpu' or is_fake_hpu():
key_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
Expand Down
17 changes: 12 additions & 5 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Set, Tuple, Type, TypeVar, Union)

from vllm.utils import (HabanaMemoryProfiler, format_bytes, is_fake_hpu,
is_pin_memory_available, make_tensor_with_pad)

import habana_frameworks.torch as htorch

import torch

from vllm.attention import AttentionMetadata, get_attn_backend
Expand All @@ -31,8 +35,6 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (HabanaMemoryProfiler, format_bytes,
is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
Expand Down Expand Up @@ -189,8 +191,9 @@ def __init__(self, model, enforce_eager):
self.model = model
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']

if not htorch.utils.internal.is_lazy() and not enforce_eager:

if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
) and not enforce_eager:
self.model = torch.compile(self.model,
backend='hpu_backend',
dynamic=False)
Expand Down Expand Up @@ -419,7 +422,9 @@ def __init__(
if model_config is not None else None)
self.device_config = (device_config
if device_config is not None else DeviceConfig())

if is_fake_hpu():
device_config.device = torch.device('cpu')
device_config.device_type = 'cpu'
self.device = self.device_config.device
self.enforce_eager = self.model_config.enforce_eager
self.max_num_seqs = self.scheduler_config.max_num_seqs
Expand Down Expand Up @@ -1406,6 +1411,8 @@ def mem_margin(self, value):


def _maybe_wrap_in_hpu_graph(*args, **kwargs):
if is_fake_hpu():
return HpuModelAdapter(*args, **kwargs)
return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(
*args, **
kwargs)) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
Expand Down
Loading
Loading