
Commit

Added changes of HPU flags
rsshaik1 committed Oct 1, 2024
1 parent 70f544c commit ec34f88
Showing 1 changed file with 4 additions and 8 deletions.
tests/lora/conftest.py: 4 additions & 8 deletions
@@ -24,6 +24,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader import get_model
+from vllm.platforms import current_platform
 
 
 class ContextIDInfo(TypedDict):
@@ -48,18 +49,13 @@ class ContextInfo(TypedDict):
 }]
 
 
-def is_hpu():
-    from importlib import util
-    return util.find_spec('habana_frameworks') is not None
-
-
 def cleanup():
     destroy_model_parallel()
     destroy_distributed_environment()
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
-    if not is_hpu():
+    if not current_platform.is_hpu():
        torch.cuda.empty_cache()
     ray.shutdown()
 
@@ -84,7 +80,7 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
 @pytest.fixture
 def dist_init():
     temp_file = tempfile.mkstemp()[1]
-    backend_type = "hccl" if is_hpu() else "nccl"
+    backend_type = "hccl" if current_platform.is_hpu() else "nccl"
     init_distributed_environment(
         world_size=1,
         rank=0,
@@ -260,7 +256,7 @@ def get_model_patched(*, model_config, device_config, **kwargs):
                              device_config=device_config,
                              **kwargs)
 
-    if is_hpu():
+    if current_platform.is_hpu():
         with patch("vllm.worker.habana_model_runner.get_model",
                    get_model_patched):
             engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
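For context, the removed is_hpu() helper detected Gaudi support by probing for the habana_frameworks package via importlib, while the replacement defers to vLLM's shared platform object imported at the top of the file. Below is a minimal sketch of the pattern this commit moves to, assuming a vLLM build where vllm.platforms exposes current_platform (as the added import indicates); the helper names are illustrative only, not part of the change.

from vllm.platforms import current_platform

import torch

# Illustrative helper mirroring the cleanup() change above: only CUDA
# devices need their cache cleared, HPU runs skip torch.cuda entirely.
def clear_accelerator_cache() -> None:
    if not current_platform.is_hpu():
        torch.cuda.empty_cache()

# Illustrative helper mirroring the dist_init() change above: Habana
# hardware uses the "hccl" collective backend, CUDA GPUs use "nccl".
def pick_distributed_backend() -> str:
    return "hccl" if current_platform.is_hpu() else "nccl"

Presumably the point of the change is to centralize hardware detection in current_platform, so the test conftest no longer carries its own importlib probe and the same fixtures run unchanged on both CUDA and Gaudi hosts.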
