Add fake HPU mode to Habana components with dummy habana_frameworks module. #250

Merged
merged 39 commits into from
Sep 17, 2024
e52c0ec
Update habana_model_runner.py
kzawora-intel Aug 13, 2024
dcc878b
Merge remote-tracking branch 'origin/habana_main' into private/kzawor…
kzawora-intel Aug 13, 2024
afffe33
Add fake HPU mode
kzawora-intel Aug 13, 2024
ed414dc
Merge remote-tracking branch 'origin/habana_main' into private/kzawor…
kzawora-intel Aug 13, 2024
ceca996
format.sh
kzawora-intel Aug 13, 2024
1976d75
tp fixes
kzawora-intel Aug 13, 2024
db4c30f
add cpu github action job
kzawora-intel Aug 13, 2024
08c9cf3
format.sh
kzawora-intel Aug 13, 2024
ebcb4ab
fix cputest job
kzawora-intel Aug 13, 2024
506e026
add better validation
kzawora-intel Aug 13, 2024
08a24b0
[WIP] Fake hpu cpu migration with dummy habana_frameworks.
jmaksymczuk Sep 6, 2024
731cab1
Add --fake_hpu to cpu-test.
jmaksymczuk Sep 6, 2024
b87d43d
Trigger cpu-test on PR to private/kzawora/fake_hpu.
jmaksymczuk Sep 6, 2024
1b09033
Create dummy habana_frameworks.torch.utils.internal.is_lazy dummy met…
jmaksymczuk Sep 6, 2024
dd8ac9b
Merge branch 'habana_main' into private/jmaksymczuk/fake_hpu_cpu
jmaksymczuk Sep 6, 2024
fb4ca58
Fix for model_runner and loader.
jmaksymczuk Sep 6, 2024
2cf66a2
Fix for ruff checks.
jmaksymczuk Sep 6, 2024
34d4141
Merge branch 'habana_main' into private/jmaksymczuk/fake_hpu_cpu
jmaksymczuk Sep 6, 2024
4d08172
Add dummy bridge_config module.
jmaksymczuk Sep 6, 2024
b7beb49
format
jmaksymczuk Sep 6, 2024
4e957d4
Merge branch 'habana_main' into private/jmaksymczuk/fake_hpu_cpu
jmaksymczuk Sep 6, 2024
e9c1064
Missing bracket.
jmaksymczuk Sep 6, 2024
91657ec
Refactor code.
jmaksymczuk Sep 6, 2024
3f1c973
format
jmaksymczuk Sep 9, 2024
0d9dff6
Fix model runner, format.
jmaksymczuk Sep 9, 2024
e5cd53a
Review changes.
jmaksymczuk Sep 10, 2024
4ab0063
Merge branch 'habana_main' into private/jmaksymczuk/fake_hpu_cpu
jmaksymczuk Sep 10, 2024
73f213a
Merge remote-tracking branch 'origin/habana_main' into private/jmaksy…
jmaksymczuk Sep 11, 2024
1d9fd69
Remove --fake_hpu, is_fake_hpu and cpu migration depends on VLLM_USE_…
jmaksymczuk Sep 11, 2024
d4efdba
format
jmaksymczuk Sep 11, 2024
a0f9f3c
Merge habana_main into private/jmaksymczuk/fake_hpu_cpu.
jmaksymczuk Sep 12, 2024
0c79630
Dummy modules based on MagicMock - improves visibility.
jmaksymczuk Sep 16, 2024
88efc02
Remove failing prompt - text formatting.
jmaksymczuk Sep 16, 2024
5864c3a
Rephrase one prompt that generated weirdly formatted output.
jmaksymczuk Sep 16, 2024
7633c4d
prompts
jmaksymczuk Sep 16, 2024
1b034d7
format
jmaksymczuk Sep 16, 2024
88611af
Merge branch 'habana_main' into private/jmaksymczuk/fake_hpu_cpu
jmaksymczuk Sep 16, 2024
b414ffb
Create needed dummy modules automatically, add comments.
jmaksymczuk Sep 17, 2024
8d01b78
format
jmaksymczuk Sep 17, 2024
34 changes: 34 additions & 0 deletions .github/workflows/cpu-test.yml
@@ -0,0 +1,34 @@
+name: cpu-test
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the habana_main branch
+  push:
+    branches:
+      - habana_main
+  pull_request:
+    branches:
+      - habana_main
+
+
+jobs:
+  cputest:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
+        pip install -r requirements-hpu.txt
+        VLLM_TARGET_DEVICE=hpu python setup.py develop
+    - name: cpu-test
+      run: |
+        VLLM_SKIP_WARMUP=true VLLM_PROMPT_SEQ_BUCKET_MAX=128 VLLM_USE_FAKE_HPU=1 python examples/offline_inference_fakehpu.py
38 changes: 38 additions & 0 deletions examples/offline_inference_fakehpu.py
@@ -0,0 +1,38 @@
+import os
+
+from vllm import LLM, SamplingParams
+
+if os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0':
+    from vllm.utils import migrate_to_cpu
+    migrate_to_cpu()
+
+# Sample prompts.
+prompts = [
+    "Berlin is the capital city of ",
+    "Louvre is located in the city of ",
+    "Barack Obama was the 44th president of ",
+    "Warsaw is the capital city of ",
+    "Gniezno is a city in ",
+    "San Francisco is located in the state of ",
+    "Llanfairpwllgwyngyll is located in country of ",
+]
+ref_answers = [
+    "Germany", "Paris", "United States", "Poland", "Poland", "California",
+    "Wales"
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0, n=1, use_beam_search=False)
+
+# Create an LLM.
+llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4)
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output, answer in zip(outputs, ref_answers):
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    assert answer in generated_text, (
+        f"The generated text does not contain the correct answer: {answer}")
+print('PASSED')
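The check in `examples/offline_inference_fakehpu.py` is a substring test, and the `assert` stops at the first prompt whose output lacks its reference answer. As a rough sketch of the same check that reports every mismatch at once, the snippet below uses plain strings in place of `RequestOutput` objects. The function `find_mismatches` and the sample completions are hypothetical and not part of this PR.

```python
from typing import List, Tuple


def find_mismatches(generated: List[str],
                    ref_answers: List[str]) -> List[Tuple[int, str]]:
    """Return (index, expected answer) for each text lacking its answer."""
    return [(i, answer)
            for i, (text, answer) in enumerate(zip(generated, ref_answers))
            if answer not in text]


# Made-up completions standing in for output.outputs[0].text.
sample_texts = ["Germany, and Berlin is", "France. The Louvre is"]
sample_refs = ["Germany", "Paris"]

mismatches = find_mismatches(sample_texts, sample_refs)
print(mismatches)
```

An empty list would correspond to the script printing `PASSED`.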
4 changes: 4 additions & 0 deletions vllm/__init__.py
@@ -1,4 +1,8 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
+from vllm.utils import is_fake_hpu, migrate_to_cpu
+
+if is_fake_hpu():
+    migrate_to_cpu()
 
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
12 changes: 7 additions & 5 deletions vllm/executor/ray_habana_executor.py
@@ -13,7 +13,7 @@
 from vllm.utils import (_run_task_with_lock,
                         error_on_invalid_device_count_status,
                         get_distributed_init_method, get_ip, get_open_port,
-                        get_vllm_instance_id, make_async)
+                        get_vllm_instance_id, is_fake_hpu, make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -87,18 +87,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         driver_ip = get_ip()
         worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("HPU", 0):
+            resource_name = "HPU" if not is_fake_hpu() else "CPU"
+            if not bundle.get(resource_name, 0):
                 continue
             scheduling_strategy = PlacementGroupSchedulingStrategy(
                 placement_group=placement_group,
                 placement_group_capture_child_tasks=True,
                 placement_group_bundle_index=bundle_id,
             )
-
+            resources = {'HPU': num_gpus} if not is_fake_hpu() else {}
+            num_cpus = 0 if not is_fake_hpu() else num_gpus
             worker = ray.remote(
-                num_cpus=0,
+                num_cpus=num_cpus,
                 num_gpus=0,
-                resources={'HPU': num_gpus},
+                resources=resources,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
             )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
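The resource selection in `_init_workers_ray` can be read as a small pure function: with a real HPU the worker requests the custom `HPU` resource and no CPUs, while in fake mode it requests CPUs and no custom resource. The sketch below restates that logic without importing Ray. The helper names `bundle_resource_name` and `ray_worker_options` are hypothetical and do not appear in the PR.

```python
from typing import Any, Dict


def bundle_resource_name(fake_hpu: bool) -> str:
    """Key looked up in each placement-group bundle."""
    return "CPU" if fake_hpu else "HPU"


def ray_worker_options(num_devices: float, fake_hpu: bool) -> Dict[str, Any]:
    """Keyword arguments mirroring those passed to ray.remote in the diff."""
    if fake_hpu:
        # Fake mode: schedule on CPUs, request no custom resource.
        return {"num_cpus": num_devices, "num_gpus": 0, "resources": {}}
    # Real mode: schedule on the custom "HPU" resource only.
    return {"num_cpus": 0, "num_gpus": 0, "resources": {"HPU": num_devices}}


real_options = ray_worker_options(1, fake_hpu=False)
fake_options = ray_worker_options(1, fake_hpu=True)
print(real_options)
print(fake_options)
```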
5 changes: 3 additions & 2 deletions vllm/executor/ray_utils.py
@@ -3,7 +3,8 @@
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_ip, is_hip, is_hpu, is_tpu, is_xpu
+from vllm.utils import (get_ip, hpu_device_string, is_hip, is_hpu, is_tpu,
+                        is_xpu)
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -97,7 +98,7 @@ def initialize_ray_cluster(
     if is_tpu():
         device_str = "TPU"
     elif is_hpu():
-        device_str = "HPU"
+        device_str = hpu_device_string()
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
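`hpu_device_string()` is one of the helpers this PR adds to `vllm/utils.py`, all driven by the `VLLM_USE_FAKE_HPU` environment variable and wrapped in `lru_cache`. One consequence worth keeping in mind is that the flag is effectively read once per process. The standalone sketch below re-creates the helpers (only `is_fake_hpu` is cached here, which is a simplification) to show that a later change to the environment is not seen until the cache is cleared.

```python
import os
from functools import lru_cache


@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
    # Any value other than '0' enables fake mode, as in the PR.
    return os.environ.get("VLLM_USE_FAKE_HPU", "0") != "0"


def hpu_device_string() -> str:
    return "cpu" if is_fake_hpu() else "hpu"


def hpu_backend_string() -> str:
    return "gloo" if is_fake_hpu() else "hccl"


os.environ["VLLM_USE_FAKE_HPU"] = "1"
is_fake_hpu.cache_clear()
first = (hpu_device_string(), hpu_backend_string())

# Changing the variable afterwards has no effect until the cache is cleared.
os.environ["VLLM_USE_FAKE_HPU"] = "0"
stale = is_fake_hpu()
is_fake_hpu.cache_clear()
fresh = is_fake_hpu()
print(first, stale, fresh)
```

In practice this suggests setting `VLLM_USE_FAKE_HPU` before the first `import vllm`, as the `cpu-test` workflow does on its command line.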
7 changes: 5 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -37,7 +37,7 @@
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_hpu, is_tpu
+from vllm.utils import is_fake_hpu, is_hpu, is_tpu
 
 logger = init_logger(__name__)
 
@@ -277,7 +277,10 @@ def load_model(self, *, model_config: ModelConfig,
                    scheduler_config: SchedulerConfig,
                    cache_config: CacheConfig) -> nn.Module:
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(self.load_config.device):
+            _device = torch.device(
+                device_config.device) if is_fake_hpu() else torch.device(
+                    self.load_config.device)
+            with _device:
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, multimodal_config,
                                           cache_config, scheduler_config)
1 change: 0 additions & 1 deletion vllm/model_executor/models/opt.py
@@ -254,7 +254,6 @@ def forward(
         if self.project_in is not None:
             inputs_embeds, _ = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
-
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
63 changes: 63 additions & 0 deletions vllm/utils.py
@@ -208,10 +208,41 @@ def is_neuron() -> bool:
 
 @lru_cache(maxsize=None)
 def is_hpu() -> bool:
+    return _is_habana_frameworks_installed() or _is_built_for_hpu()
+
+
+@lru_cache(maxsize=None)
+def is_fake_hpu() -> bool:
+    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'
+
+
+@lru_cache(maxsize=None)
+def hpu_device_string():
+    device_string = 'hpu' if not is_fake_hpu() else 'cpu'
+    return device_string
+
+
+@lru_cache(maxsize=None)
+def hpu_backend_string():
+    backend_string = 'hccl' if not is_fake_hpu() else 'gloo'
+    return backend_string
+
+
+@lru_cache(maxsize=None)
+def _is_habana_frameworks_installed() -> bool:
     from importlib import util
     return util.find_spec('habana_frameworks') is not None
 
 
+@lru_cache(maxsize=None)
+def _is_built_for_hpu() -> bool:
+    from importlib.metadata import PackageNotFoundError, version
+    try:
+        return "gaudi" in version("vllm")
+    except PackageNotFoundError:
+        return False
+
+
 @lru_cache(maxsize=None)
 def is_tpu() -> bool:
     try:
@@ -624,18 +655,24 @@ def __init__(self, device=None):
 
     @staticmethod
     def current_device_memory_usage() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory - free_hpu_memory
 
     @staticmethod
     def current_free_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         free_hpu_memory, _ = torch.hpu.mem_get_info()
         return free_hpu_memory
 
     @staticmethod
     def total_device_memory() -> float:
+        if is_fake_hpu():
+            return 0
         # Return the device memory usage in bytes.
         _, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory
@@ -1088,3 +1125,29 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
+
+
+def migrate_to_cpu():
+    import importlib
+    from unittest.mock import MagicMock
+
+    torch.hpu = MagicMock(name="torch.hpu")
+
+    # Adding dummy submodules to habana_frameworks.torch for cpu-test,
+    # functions from dummy modules will do nothing by default
+    spec = importlib.util.spec_from_loader('habana_frameworks', loader=None)
+    sys.modules['habana_frameworks'] = MagicMock()
+    sys.modules['habana_frameworks'].__spec__ = spec
+
+    builtin_import = __builtins__['__import__']  # type: ignore
+
+    def import_wrapper(name, *args, **kwargs):
+        if 'habana_frameworks' in name:
+            sys.modules[name] = MagicMock()
+        return builtin_import(name, *args, **kwargs)
+
+    __builtins__['__import__'] = import_wrapper
+
+    # In case you want to mock a function to actually do something
+    import habana_frameworks.torch as htorch
+    htorch.utils.internal.is_lazy.return_value = False
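The core trick in `migrate_to_cpu()` is that Python's import system returns whatever object is registered in `sys.modules`, so a `MagicMock` placed there behaves like a module whose attributes all exist and whose functions do nothing. The sketch below is a minimal, standalone illustration of that idea only. It uses a made-up module name, `fake_accel_frameworks`, and a hypothetical helper, `install_dummy_module`, and it leaves out the `__import__` wrapper that the PR uses to register submodules on demand.

```python
import importlib.util
import sys
from unittest.mock import MagicMock


def install_dummy_module(name: str) -> MagicMock:
    """Register a MagicMock under `name` so `import name` succeeds."""
    dummy = MagicMock(name=name)
    # A spec keeps importlib.util.find_spec(name) from raising ValueError.
    dummy.__spec__ = importlib.util.spec_from_loader(name, loader=None)
    sys.modules[name] = dummy
    return dummy


dummy = install_dummy_module("fake_accel_frameworks")
# Give one mocked function a concrete return value, as the PR does
# for htorch.utils.internal.is_lazy.
dummy.torch.utils.internal.is_lazy.return_value = False

import fake_accel_frameworks  # resolved from sys.modules, not from disk

lazy = fake_accel_frameworks.torch.utils.internal.is_lazy()
found = importlib.util.find_spec("fake_accel_frameworks") is not None
print(lazy, found)
```

Attribute access on the mock covers nested names such as `torch.utils.internal`. A statement like `import fake_accel_frameworks.torch` would still fail without extra registration, which appears to be why the PR also wraps `__import__`.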
4 changes: 2 additions & 2 deletions vllm/worker/cache_engine.py
@@ -6,7 +6,7 @@
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu,
                         is_pin_memory_available)
 
 logger = init_logger(__name__)
@@ -78,7 +78,7 @@ def _allocate_kv_cache(
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_attention_layers):
-            if device == 'hpu':
+            if device == 'hpu' or is_fake_hpu():
                 key_cache = torch.zeros(kv_cache_shape,
                                         dtype=self.dtype,
                                         device=device)
11 changes: 7 additions & 4 deletions vllm/worker/habana_model_runner.py
@@ -35,7 +35,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (HabanaMemoryProfiler, format_bytes,
+from vllm.utils import (HabanaMemoryProfiler, format_bytes, is_fake_hpu,
                         is_pin_memory_available, make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
@@ -244,7 +244,8 @@ def __init__(self, model, block_size, dtype, enforce_eager):
                                                '0').lower() in ['1', 'true']
         self.block_size = block_size
         self.dtype = dtype
-        if not htorch.utils.internal.is_lazy() and not enforce_eager:
+        if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
+        ) and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
                                        dynamic=False)
@@ -507,7 +508,9 @@ def __init__(
                                if model_config is not None else None)
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
-
+        if is_fake_hpu():
+            device_config.device = torch.device('cpu')
+            device_config.device_type = 'cpu'
         self.device = self.device_config.device
         self.enforce_eager = self.model_config.enforce_eager
         self.max_num_seqs = self.scheduler_config.max_num_seqs
@@ -612,7 +615,7 @@ def load_model(self) -> None:
                                           mark_only_scales_as_const=True)
                 logger.info("Preparing model with INC took %s",
                             m_inc.get_summary_string())
-            else:
+            elif not is_fake_hpu():
                 self.model = self.model.to("hpu")
                 htcore.mark_step()
             torch.hpu.synchronize()
18 changes: 14 additions & 4 deletions vllm/worker/habana_worker.py
@@ -21,7 +21,8 @@
 from vllm.model_executor import set_random_seed
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import HabanaMemoryProfiler, format_bytes
+from vllm.utils import (HabanaMemoryProfiler, format_bytes, hpu_backend_string,
+                        hpu_device_string, is_fake_hpu)
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.habana_model_runner import HabanaModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
@@ -105,6 +106,8 @@ def init_device(self) -> None:
         if self.device_config.device.type == "hpu":
             self.device = torch.device("hpu")
             torch.hpu.set_device(self.device)
+        elif self.device_config.device_type == "cpu":
+            self.device = torch.device("cpu")
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -138,6 +141,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
+        if is_fake_hpu():
+            cache_block_size = self.get_cache_block_size_bytes()
+            fake_hpu_cache_alloc = 4 * 2**30  # take 4 GiB flat on fake hpu
+            return fake_hpu_cache_alloc // cache_block_size, 0
         with HabanaMemoryProfiler() as m:
             self.model_runner.profile_run()
             torch.hpu.synchronize()
@@ -335,11 +342,12 @@ def init_worker_distributed_environment(
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    backend = hpu_backend_string()
     init_distributed_environment(parallel_config.world_size,
                                  rank,
                                  distributed_init_method,
                                  local_rank,
-                                 backend='hccl')
+                                 backend=backend)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
@@ -356,15 +364,17 @@ def init_worker_distributed_environment(
                 "distributed_init_method must be set if torch.distributed "
                 "is not already initialized")
         else:
+            backend = hpu_backend_string()
             torch.distributed.init_process_group(
-                backend="hccl",
+                backend=backend,
                 world_size=parallel_config.world_size,
                 rank=rank,
                 init_method=distributed_init_method,
             )
 
     # A small all_reduce for warmup & checking conformance.
-    dummy_tensor_hpu = torch.ones(1).to('hpu')
+    device = hpu_device_string()
+    dummy_tensor_hpu = torch.ones(1).to(device)
     torch.distributed.all_reduce(dummy_tensor_hpu)
     assert dummy_tensor_hpu.item() == parallel_config.world_size
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
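In fake-HPU mode `determine_num_available_blocks` skips memory profiling and divides a flat 4 GiB budget by the cache block size, returning zero CPU swap blocks. The arithmetic can be sketched as follows. The block-size formula and the model dimensions below are illustrative assumptions (one key block and one value block per layer, made-up layer and head counts), not values taken from this PR or from `get_cache_block_size_bytes()`.

```python
def cache_block_size_bytes(block_size: int, num_layers: int,
                           num_kv_heads: int, head_size: int,
                           dtype_bytes: int) -> int:
    """Bytes for one KV-cache block across all layers (assumed layout)."""
    # One key block plus one value block per layer.
    per_layer = 2 * block_size * num_kv_heads * head_size * dtype_bytes
    return num_layers * per_layer


FAKE_HPU_CACHE_ALLOC = 4 * 2**30  # 4 GiB flat budget, as in the diff

# Hypothetical dimensions chosen only for the example.
block_bytes = cache_block_size_bytes(block_size=128, num_layers=12,
                                     num_kv_heads=12, head_size=64,
                                     dtype_bytes=2)
num_hpu_blocks = FAKE_HPU_CACHE_ALLOC // block_bytes
num_cpu_blocks = 0  # fake mode reports no swap space

print(block_bytes, num_hpu_blocks, num_cpu_blocks)
```

With these assumed numbers each block takes 4,718,592 bytes, so the 4 GiB budget yields 910 blocks.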