Fixed lora manager tests (#315)
Added the HPU-related changes alongside the existing GPU path in conftest.py and
added test_lora_manager_hpu.py.
vivekgoe authored Oct 1, 2024
2 parents 3010f8c + ec34f88 commit c7b1509
Showing 2 changed files with 565 additions and 9 deletions.
21 changes: 12 additions & 9 deletions tests/lora/conftest.py
@@ -24,6 +24,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader import get_model
+from vllm.platforms import current_platform


 class ContextIDInfo(TypedDict):
@@ -48,18 +49,13 @@ class ContextInfo(TypedDict):
 }]


-def is_hpu():
-    from importlib import util
-    return util.find_spec('habana_frameworks') is not None
-
-
 def cleanup():
     destroy_model_parallel()
     destroy_distributed_environment()
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
-    if not is_hpu():
+    if not current_platform.is_hpu():
         torch.cuda.empty_cache()
     ray.shutdown()

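Note: the ad-hoc importlib probe is replaced by the platform check that vllm.platforms already provides. A minimal sketch of the resulting guard, for illustration only (reset_device_cache is a hypothetical helper name, not part of the commit):

import torch

from vllm.platforms import current_platform


def reset_device_cache():
    # Hypothetical helper: torch.cuda.empty_cache() exists only for CUDA,
    # so the call is skipped when running on HPU.
    if not current_platform.is_hpu():
        torch.cuda.empty_cache()
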
Expand All @@ -84,12 +80,13 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
@pytest.fixture
def dist_init():
temp_file = tempfile.mkstemp()[1]
backend_type = "hccl" if current_platform.is_hpu() else "nccl"
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
backend=backend_type,
)
initialize_model_parallel(1, 1)
yield
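
Note: the process-group backend now follows the platform as well. A sketch of the same selection factored into a helper (pick_backend is a made-up name; only is_hpu() and the "hccl"/"nccl" strings come from this diff):

from vllm.platforms import current_platform


def pick_backend() -> str:
    # HPU runs collectives over Habana's HCCL; everything else keeps NCCL,
    # mirroring the dist_init fixture above.
    return "hccl" if current_platform.is_hpu() else "nccl"

It would be used as init_distributed_environment(..., backend=pick_backend()).
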
@@ -259,8 +256,14 @@ def get_model_patched(*, model_config, device_config, **kwargs):
                              device_config=device_config,
                              **kwargs)

-    with patch("vllm.worker.model_runner.get_model", get_model_patched):
-        engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
+    if current_platform.is_hpu():
+        with patch("vllm.worker.habana_model_runner.get_model",
+                   get_model_patched):
+            engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
+    else:
+        with patch("vllm.worker.model_runner.get_model", get_model_patched):
+            engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)

     yield engine.llm_engine
     del engine
     cleanup()
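Note: the fixture now patches get_model in whichever model runner the platform actually uses. The same dispatch can be expressed with a single patch target; a sketch under the assumption that both module paths behave as in the diff (build_patched_engine is a hypothetical name):

from unittest.mock import patch

import vllm
from vllm.platforms import current_platform

# Module whose get_model() the worker imports on this platform
# (paths taken from the diff above).
RUNNER_MODULE = ("vllm.worker.habana_model_runner"
                 if current_platform.is_hpu() else "vllm.worker.model_runner")


def build_patched_engine(get_model_patched):
    # Hypothetical helper: build the test engine with get_model() swapped out.
    with patch(f"{RUNNER_MODULE}.get_model", get_model_patched):
        return vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)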
tests/lora/test_lora_manager_hpu.py: 553 additions & 0 deletions (new file; diff not shown)
