
Add fake HPU mode to Habana components #180

Closed
wants to merge 10 commits
34 changes: 34 additions & 0 deletions .github/workflows/cpu-test.yml
@@ -0,0 +1,34 @@
name: cpu-test

on:
# Trigger the workflow on push or pull request,
# but only for the habana_main branch
push:
branches:
- habana_main
pull_request:
branches:
- habana_main


Review comment: What do you think about also adding habana_next? Just temporarily, for as long as we maintain two branches.
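A minimal sketch of the dual-branch trigger suggested above (habana_next is taken from the comment, not from the merged workflow):

on:
  push:
    branches:
      - habana_main
      - habana_next   # temporary, while both branches are maintained
  pull_request:
    branches:
      - habana_main
      - habana_next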



jobs:
cputest:
runs-on: ubuntu-latest


Review comment: Wouldn't it be safer to use a hardcoded Ubuntu version?
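A sketch of what pinning the runner could look like (the specific Ubuntu version here is an assumption, not part of this PR):

jobs:
  cputest:
    runs-on: ubuntu-22.04  # pinned instead of ubuntu-latest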

strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
pip install -r requirements-hpu.txt
VLLM_TARGET_DEVICE=hpu python setup.py develop
- name: cpu-test
run: |
VLLM_SKIP_WARMUP=true VLLM_PROMPT_SEQ_BUCKET_MAX=128 python examples/offline_inference_fakehpu.py


Review comment: Running with warmup would be an additional bonus validation, don't you think? It would probably be better to limit the number of buckets, so the test doesn't take as much time, rather than disabling warmup.
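A sketch of the alternative the comment suggests: keep warmup on but shrink the bucket space so it stays fast. The VLLM_PROMPT_SEQ_BUCKET_MIN/STEP names are assumptions about the fork's bucketing env vars; only VLLM_PROMPT_SEQ_BUCKET_MAX appears in this PR:

- name: cpu-test
  run: |
    VLLM_PROMPT_SEQ_BUCKET_MIN=128 VLLM_PROMPT_SEQ_BUCKET_STEP=128 VLLM_PROMPT_SEQ_BUCKET_MAX=128 python examples/offline_inference_fakehpu.py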

33 changes: 33 additions & 0 deletions examples/offline_inference_fakehpu.py
@@ -0,0 +1,33 @@
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
"Berlin is the capital city of ",
"Louvre is located in the city called ",
"Barack Obama was the 44th president of ",
"Warsaw is the capital city of ",
"Gniezno is a city in ",
"Hebrew is an official state language of ",
"San Francisco is located in the state of ",
"Llanfairpwllgwyngyll is located in country of ",
]
ref_answers = [
"Germany", "Paris", "United States", "Poland", "Poland", "Israel",
"California", "Wales"
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0, n=1, use_beam_search=False)

# Create an LLM.
llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output, answer in zip(outputs, ref_answers):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
assert answer in generated_text, (
f"The generated text does not contain the correct answer: {answer}")
print('PASSED')
9 changes: 6 additions & 3 deletions vllm/distributed/device_communicators/hpu_communicator.py
@@ -3,8 +3,9 @@
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform
from vllm.utils import is_fake_hpu

if current_platform.is_hpu():
if current_platform.is_hpu() and not is_fake_hpu():
import habana_frameworks.torch as htorch # noqa: F401


@@ -22,7 +23,8 @@ def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
if not is_fake_hpu():
htorch.core.mark_step()
dist.all_reduce(x, group=self.group)
return x

@@ -37,7 +39,8 @@ def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
dtype=x.dtype,
device=x.device)
# All-gather.
htorch.core.mark_step()
if not is_fake_hpu():
htorch.core.mark_step()
dist.all_gather_into_tensor(output_tensor, x, group=self.group)
# Reshape
output_tensor = output_tensor.movedim(0, dim)
12 changes: 7 additions & 5 deletions vllm/executor/ray_habana_executor.py
@@ -13,7 +13,7 @@
from vllm.utils import (_run_task_with_lock,
error_on_invalid_device_count_status,
get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
get_vllm_instance_id, is_fake_hpu, make_async)

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -87,18 +87,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("HPU", 0):
resource_name = "HPU" if not is_fake_hpu() else "CPU"
if not bundle.get(resource_name, 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)

resources = {'HPU': num_gpus} if not is_fake_hpu() else {}
num_cpus = 0 if not is_fake_hpu() else num_gpus
worker = ray.remote(
num_cpus=0,
num_cpus=num_cpus,
num_gpus=0,
resources={'HPU': num_gpus},
resources=resources,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
4 changes: 2 additions & 2 deletions vllm/executor/ray_utils.py
@@ -3,7 +3,7 @@
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_ip, is_hip, is_hpu, is_tpu, is_xpu
from vllm.utils import get_ip, is_fake_hpu, is_hip, is_hpu, is_tpu, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -97,7 +97,7 @@ def initialize_ray_cluster(
if is_tpu():
device_str = "TPU"
elif is_hpu():
device_str = "HPU"
device_str = "HPU" if not is_fake_hpu() else 'CPU'
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
5 changes: 4 additions & 1 deletion vllm/hpu/cache_ops.py
@@ -5,7 +5,10 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

import habana_frameworks.torch as htorch
from vllm.utils import is_fake_hpu

if not is_fake_hpu():
import habana_frameworks.torch as htorch
import torch


6 changes: 5 additions & 1 deletion vllm/hpu/ops.py
@@ -7,7 +7,11 @@
import os
from typing import Optional

import habana_frameworks.torch as htorch
from vllm.utils import is_fake_hpu

if not is_fake_hpu():
import habana_frameworks.torch as htorch

import torch
import torch.nn.functional as F

11 changes: 8 additions & 3 deletions vllm/hpu/utils.py
@@ -7,18 +7,23 @@

from functools import wraps

import habana_frameworks.torch as htorch
from vllm.utils import is_fake_hpu

if not is_fake_hpu():
import habana_frameworks.torch as htorch


def with_mark_steps(fn):

@wraps(fn)
def wrapped(*args, **kwargs):
htorch.core.mark_step()
if not is_fake_hpu():
htorch.core.mark_step()
result = fn(*args, **kwargs)
del args
del kwargs
htorch.core.mark_step()
if not is_fake_hpu():
htorch.core.mark_step()
return result

return wrapped
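For context, a minimal usage sketch of the decorator above; the decorated function is hypothetical. On a real HPU the call is bracketed by mark_step(), while in fake-HPU mode the marks are skipped and the function runs as plain PyTorch code:

import torch
from vllm.hpu.utils import with_mark_steps

@with_mark_steps
def scale(x: torch.Tensor) -> torch.Tensor:
    # Illustrative op only; the wrapper adds mark_step() before and after on real HPU.
    return x * 2

y = scale(torch.ones(4))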
2 changes: 1 addition & 1 deletion vllm/model_executor/models/opt.py
@@ -100,6 +100,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# import pdb; pdb.set_trace()


Review comment: I guess this comment is not needed.

qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
@@ -254,7 +255,6 @@ def forward(
if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds

for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
26 changes: 26 additions & 0 deletions vllm/utils.py
@@ -207,10 +207,30 @@ def is_neuron() -> bool:

@lru_cache(maxsize=None)
def is_hpu() -> bool:
return _is_habana_frameworks_installed() or _is_built_for_hpu()


@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0' or (
not _is_habana_frameworks_installed() and _is_built_for_hpu())


@lru_cache(maxsize=None)
def _is_habana_frameworks_installed() -> bool:
from importlib import util
return util.find_spec('habana_frameworks') is not None


@lru_cache(maxsize=None)
def _is_built_for_hpu() -> bool:
from importlib.metadata import PackageNotFoundError, version
try:
return "gaudi" in version("vllm")
except PackageNotFoundError:
return False


@lru_cache(maxsize=None)
def is_tpu() -> bool:
try:
@@ -623,18 +643,24 @@ def __init__(self, device=None):

@staticmethod
def current_device_memory_usage() -> float:
if is_fake_hpu():
return 0
# Return the device memory usage in bytes.
free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
return total_hpu_memory - free_hpu_memory

@staticmethod
def current_free_device_memory() -> float:
if is_fake_hpu():
return 0
# Return the device memory usage in bytes.
free_hpu_memory, _ = torch.hpu.mem_get_info()
return free_hpu_memory

@staticmethod
def total_device_memory() -> float:
if is_fake_hpu():
return 0
# Return the device memory usage in bytes.
_, total_hpu_memory = torch.hpu.mem_get_info()
return total_hpu_memory
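To illustrate the new helpers above, a small sketch under the assumption that VLLM_USE_FAKE_HPU is set before vllm.utils is first queried (is_fake_hpu() is lru_cached, so later changes to the env var are not picked up):

import os
os.environ['VLLM_USE_FAKE_HPU'] = '1'  # force fake-HPU mode; habana_frameworks not required

from vllm.utils import is_fake_hpu, is_hpu

assert is_fake_hpu()
# On a wheel built for HPU ("gaudi" in the vllm version), is_hpu() is also True,
# so the HPU code paths are exercised on a plain CPU host.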
4 changes: 2 additions & 2 deletions vllm/worker/cache_engine.py
@@ -6,7 +6,7 @@
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu,
is_pin_memory_available)

logger = init_logger(__name__)
@@ -78,7 +78,7 @@ def _allocate_kv_cache(
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_attention_layers):
if device == 'hpu':
if device == 'hpu' or is_fake_hpu():
key_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)