diff --git a/examples/lora_inference_hpu.py b/examples/lora_inference_hpu.py new file mode 100644 index 0000000000000..b8154a29a82bb --- /dev/null +++ b/examples/lora_inference_hpu.py @@ -0,0 +1,47 @@ +from huggingface_hub import snapshot_download + +from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest + +sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + +llm = LLM(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_num_seqs=2, + dtype='bfloat16') + +sampling_params = SamplingParams(temperature=0, + max_tokens=1024, + stop=["[/assistant]"]) + +prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? 
[/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 +] + +expected_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501 + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 +] + +outputs = llm.generate(prompts, + sampling_params, + lora_request=LoRARequest("sql_adapter", 1, + sql_lora_path)) + +for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + match = expected_output[i] == generated_text + if not match: + print( + f"Comparison failed for request_id::{i}\n\t[PROMPT]{prompt!r}\n\t[GENERATED]{generated_text!r}\n\t[EXPECTED]{expected_output[i]!r}" # noqa: E501 + ) diff --git a/tests/conftest.py b/tests/conftest.py index 59510075b0063..cfb7cf56b519a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -590,9 +590,17 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog +def is_hpu(): + from importlib import util + return util.find_spec('habana_frameworks') is not None + + @pytest.fixture(scope="session") def num_gpus_available(): """Get number of GPUs without initializing the CUDA context in current process.""" + if is_hpu(): + return torch.hpu.device_count() + return cuda_device_count_stateless() diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 0bcae5b0c96dc..3e4c8be6dbaa3 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -48,13 +48,19 @@ class ContextInfo(TypedDict): }] +def is_hpu(): + from importlib import util + return util.find_spec('habana_frameworks') is not None + + def cleanup(): destroy_model_parallel() destroy_distributed_environment() with contextlib.suppress(AssertionError): torch.distributed.destroy_process_group() gc.collect() - torch.cuda.empty_cache() + if not is_hpu(): + torch.cuda.empty_cache() ray.shutdown() diff --git a/tests/lora/test_llama_hpu.py b/tests/lora/test_llama_hpu.py new file mode 100644 index 0000000000000..dfd551f2ca043 --- /dev/null +++ b/tests/lora/test_llama_hpu.py @@ -0,0 +1,100 @@ +from multiprocessing import Process +from typing import List + +from conftest import cleanup + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When 
Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def _test_llama_lora(sql_lora_files, tp_size): + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + dtype='float32', + tensor_parallel_size=tp_size) + + expected_no_lora_output = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501 + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? 
[/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501 + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501 + ] + expected_lora_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 + ] + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output + + print("removing lora") + cleanup() + + +def test_llama_lora_1x(sql_lora_files): + p = Process(target=_test_llama_lora, args=(sql_lora_files, 1)) + p.start() + p.join() + assert p.exitcode == 0 + + +def test_llama_lora_2x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_lora, args=(sql_lora_files, 2)) + p.start() + p.join() + assert p.exitcode == 0 + + +def 
test_llama_lora_4x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_lora, args=(sql_lora_files, 4)) + p.start() + p.join() + assert p.exitcode == 0 diff --git a/tests/lora/test_lora_hpu.py b/tests/lora/test_lora_hpu.py new file mode 100644 index 0000000000000..ddbab66e166b3 --- /dev/null +++ b/tests/lora/test_lora_hpu.py @@ -0,0 +1,221 @@ +import pytest +import torch + +from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice + +from .utils import DummyLoRAManager + +TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] +QKV_TENSOR_SIZES = [ + (8192, 1024, 1024), + (8192 // 8, 1024 // 8, 1024 // 8), + (4096, 4096, 4096), + (4096 // 2, 4096 // 2, 4096 // 2), +] +BATCH_SIZES = [8, 32, 256] +RANKS = [8] +DTYPES = [torch.bfloat16] +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} +MAX_LORAS = 8 + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora(m, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m, n], device="hpu", dtype=dtype) + + manager.init_random_lora(module_name, weight, rank=rank) + lora = manager.get_module_lora(module_name) + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = input @ lora.lora_a @ lora.lora_b * lora.scaling + + lora_a_stack = torch.zeros(MAX_LORAS + 1, + 1, + lora.lora_a.shape[1], + lora.lora_a.shape[0], + device="hpu", + dtype=dtype) + lora_b_stack = torch.zeros(MAX_LORAS + 1, + 1, + lora.lora_b.shape[1], + lora.lora_b.shape[0], + device="hpu", + dtype=dtype) + for i in range(MAX_LORAS): + lora_a_stack[i][0] = lora.lora_a.T + lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T + + output = torch.zeros(k, m, device="hpu", dtype=dtype) + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), + output) + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.full((len(input), ), -1, device="hpu"), output) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: + if m % 2 != 0: + pytest.skip("m must be divisible by 2") + if m // 2 not in TENSOR_SIZES: + pytest.skip("m//2 must be in TENSOR_SIZES") + + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m // 2, n], device="hpu", dtype=dtype) + + manager.init_random_lora(module_name + "1", weight, rank=rank) + lora_1 = manager.get_module_lora(module_name + "1") + manager.init_random_lora(module_name + "2", weight, rank=rank) + lora_2 = manager.get_module_lora(module_name + "2") + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = torch.cat([ + input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, + input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_1.lora_a.shape[1], + lora_1.lora_a.shape[0], + device="hpu", + dtype=dtype) for i in 
range(2) + ] + lora_b_stacks = [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_1.lora_b.shape[1], + lora_1.lora_b.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + for i in range(MAX_LORAS): + lora_a_stacks[0][i][0] = lora_1.lora_a.T + lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T + lora_a_stacks[1][i][0] = lora_2.lora_a.T + lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T + + output = torch.zeros(k, m, device="hpu", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output, + (m // 2, m // 2)) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="hpu"), + output, (m // 2, m // 2)) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight_q = torch.empty(qkv[0], n, device="hpu", dtype=dtype) + weight_kv = torch.empty(qkv[1], n, device="hpu", dtype=dtype) + + manager.init_random_lora(module_name + "q", weight_q, rank=rank) + lora_q = manager.get_module_lora(module_name + "q") + manager.init_random_lora(module_name + "k", weight_kv, rank=rank) + lora_k = manager.get_module_lora(module_name + "k") + manager.init_random_lora(module_name + "v", weight_kv, rank=rank) + lora_v = manager.get_module_lora(module_name + "v") + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = torch.cat([ + input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, + input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, + input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_q.lora_a.shape[1], + lora_q.lora_a.shape[0], + device="hpu", + dtype=dtype) + ] + [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_k.lora_a.shape[1], + lora_k.lora_a.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_q.lora_b.shape[1], + lora_q.lora_b.shape[0], + device="hpu", + dtype=dtype) + ] + [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_k.lora_b.shape[1], + lora_k.lora_b.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + for i in range(MAX_LORAS): + lora_a_stacks[0][i][0] = lora_q.lora_a.T + lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T + lora_a_stacks[1][i][0] = lora_k.lora_a.T + lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T + lora_a_stacks[2][i][0] = lora_v.lora_a.T + lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T + + output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output, + (qkv[0], qkv[1], qkv[2])) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="hpu"), + output, (qkv[0], qkv[1], qkv[2])) + assert torch.allclose(torch.zeros_like(output), output) + + 
manager.reset_lora() diff --git a/tests/lora/test_multilora_hpu.py b/tests/lora/test_multilora_hpu.py new file mode 100644 index 0000000000000..edca64fd5a2ae --- /dev/null +++ b/tests/lora/test_multilora_hpu.py @@ -0,0 +1,130 @@ +from multiprocessing import Process +from typing import List, Optional, Tuple + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.lora.request import LoRARequest + + +def create_test_prompts( + lora_path: str +) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: + """Create a list of test prompts with their sampling parameters. + + 2 requests for base model, 4 requests for the LoRA. We define 2 + different LoRA adapters (using the same model for demo purposes). + """ + return [ + ("A robot may not injure a human being", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), None), + ("To be or not to be,", + SamplingParams(temperature=0.8, + top_k=5, + presence_penalty=0.2, + max_tokens=128), None), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora2", 2, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? 
[/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + result = {} + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + result[ + request_output.request_id] = request_output.outputs[0].text + return result + + +expected_output = [ + " or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194", # noqa: E501 + " that is the question.\nIt is the most famous line in all of Shakespeare's plays and one of the most famous in English literature. The question is not whether or not to be, but rather the question of who to be.\nIn Hamlet's case, the question is whether or not to be a good person. He is torn between the goodness of his father and the evil of his mother.\nThe question is a difficult one, and one that has been asked many times before. 
In Hamlet's case, the question is whether or not to be a good person, and he is torn between the", # noqa: E501 + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' " # noqa: E501 +] + + +def _test_llama_multilora(sql_lora_files, tp_size): + """Main function that sets up and runs the prompt processing.""" + engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_loras=2, + max_lora_rank=8, + max_num_seqs=16, + dtype='float32', + tensor_parallel_size=tp_size) + engine = LLMEngine.from_engine_args(engine_args) + test_prompts = create_test_prompts(sql_lora_files) + results = process_requests(engine, test_prompts) + generated_texts = [results[key] for key in sorted(results)] + assert generated_texts == expected_output + + +def test_llama_multilora_1x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_multilora, args=(sql_lora_files, 1)) + p.start() + p.join() + assert p.exitcode == 0 + + +def test_llama_multilora_2x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_multilora, args=(sql_lora_files, 2)) + p.start() + p.join() + assert p.exitcode == 0 + + +def test_llama_multilora_4x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_multilora, args=(sql_lora_files, 4)) + p.start() + p.join() + assert p.exitcode == 0 diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b73cf5bf55324..6ed985e72e6b3 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -3,6 +3,7 @@ import torch from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.utils import get_device class DummyLoRAManager: @@ -28,16 +29,16 @@ def init_random_lora(self, lora_alpha=1, lora_a=torch.rand([weight.shape[1], rank], dtype=weight.dtype, - device="cuda"), + device=get_device()), lora_b=torch.rand([rank, weight.shape[0]], dtype=weight.dtype, - device="cuda"), + device=get_device()), ) if generate_embeddings_tensor: lora.embeddings_tensor = torch.rand(5, generate_embeddings_tensor, dtype=weight.dtype, - device="cuda") + device=get_device()) self.set_module_lora(module_name, lora) return lora @@ -53,8 +54,8 @@ def init_lora(self, module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([input_dim, rank], device="cuda"), - lora_b=torch.rand([rank, output_dim], device="cuda"), + lora_a=torch.rand([input_dim, rank], device=get_device()), + lora_b=torch.rand([rank, output_dim], device=get_device()), embeddings_tensor=embeddings_tensor, ) self.set_module_lora(module_name, lora) diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py index 80f8037a2d043..baeaec5afa371 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/habana_executor.py @@ -154,29 +154,36 @@ def execute_model( return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." 
+ return self.driver_worker.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") - - def list_loras(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def check_health(self) -> None: # GPUExecutor will always be healthy as long as diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 4cc626363f226..dc7fa295f3f4f 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -163,6 +163,80 @@ def prompt_attention( return attn_weights +def dispatch_bgmv_linear( + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indices: torch.LongTensor, + layer_idx: int, + scale: float, +): + """ + `wa_t_all` and `wb_t_all` contain all LoRA A and LoRA B weight matrices + stacked into single tensors, assuming the same rank. HPU handles no-LoRA + requests using zero valued A and B tensors. These zero valued tensors are + appended at the end of `wa_t_all` and `wb_t_all` during initialization. For + custom BGMV, the corresponding `wa` and `wb` for each batch are created + based on the lora_index of each sample. + + For example: + `wa_t_all` is a tensor of shape (num_loras, num_layers, lora_rank, + hidden_dim), where `wa_t_all[-1]` is a zero valued tensor which handles the + no-LoRA case. The `wa` tensor for a batch of size batch_size will have + a shape of (batch_size, num_layers, hidden_dim, lora_rank). + + This method avoids a for-loop as well as graph breaks. + """ + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + max_loras = wa_t_all.size(0) + # Wrap-around for negative indices + indices = indices % max_loras + wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + + x = x.unsqueeze(1) + out = x @ wa + out = out @ wb + out = out.squeeze(1) + y += out * scale + + +def dispatch_bgmv_embedding( + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + indices: torch.LongTensor, + layer_idx: int, + scale: float, +): + """ + `wa_t_all` contains all LoRA A weight matrices stacked into a single tensor, + assuming the same rank. HPU handles no-LoRA requests using a zero valued A + tensor.
This zero valued tensor is appended at the end of `wa_t_all` during + initialization. For custom BGMV, the corresponding `wa` for each batch is + created based on the lora_index of the sample. + + For example: + `wa_t_all` is a tensor of shape (num_loras, num_layers, lora_rank, + hidden_dim), where `wa_t_all[-1]` is a zero valued tensor which handles the + no-LoRA case. The `wa` tensor for a batch of size batch_size will have a + shape of (batch_size, num_layers, lora_rank, hidden_dim). + + This method avoids a for-loop as well as graph breaks. + """ + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + max_loras = wa_t_all.size(0) + # Wrap-around for negative indices + indices = indices % max_loras + wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + + x = x.unsqueeze(1) + out = x @ wa + out = out.squeeze(1) + y += out * scale + class MoeMatmul(torch.nn.Module): def __init__(self): super().__init__() diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 87de285a373a2..4a45f3fda88f1 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -27,6 +27,10 @@ LinearScalingRotaryEmbedding, RotaryEmbedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.utils import is_hpu + +if is_hpu(): + from vllm.hpu.ops import dispatch_bgmv_embedding, dispatch_bgmv_linear if TYPE_CHECKING: pass @@ -89,7 +93,11 @@ def _apply_lora( x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) - add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + if is_hpu(): + dispatch_bgmv_linear(output, x, lora_a_stacked, lora_b_stacked, + indices, 0, 1.0) + else: + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) return output.view_as(org_output) @@ -127,9 +135,15 @@ def _apply_lora_packed_nslice( indices = indices.view(-1) offset_left = 0 for slice_idx in range(len(output_slices)): - add_lora_slice(output, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, - output_slices[slice_idx]) + if is_hpu(): + dispatch_bgmv_linear( + output[:, offset_left:offset_left + output_slices[slice_idx]], + x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], + indices, 0, 1.0) + else: + add_lora_slice(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, 0, 1.0, + offset_left, output_slices[slice_idx]) offset_left += output_slices[slice_idx] return output.view_as(org_output) @@ -330,8 +344,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings = full_lora_a_embeddings.view( full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1], -1) - bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + if is_hpu(): + dispatch_bgmv_embedding(full_output, full_lora_a_embeddings, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + else: + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) return full_output.view_as(full_output_org) @classmethod diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e1ede7d4d710a..30d2fd9502977 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -24,7 +24,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.utils import is_pin_memory_available +from vllm.utils import
get_device, is_hpu, is_pin_memory_available logger = init_logger(__name__) @@ -93,7 +93,7 @@ def convert_mapping( long_lora_offsets: Optional[torch.Tensor] = None if long_lora_context: long_lora_offsets = torch.zeros(len(index_mapping_indices), - device="cuda", + device=get_device(), dtype=torch.long) prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 @@ -118,9 +118,9 @@ def convert_mapping( if long_lora_context: assert long_lora_offsets is not None indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + indices = torch.tensor(indices_list, dtype=torch.long, device=get_device()) prompt_mapping_tensor = torch.tensor(prompt_mapping, - device="cuda", + device=get_device(), dtype=torch.long) embeddings_indices = torch.stack([ indices[2] * extra_vocab_size, @@ -131,10 +131,10 @@ def convert_mapping( sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = ( - torch.arange( - 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + - (sampler_indices_padded * len(sampler_indices_padded))) + sampler_indices_padded = (torch.arange( + 0, len(sampler_indices_padded), device=get_device(), dtype=torch.long) + + (sampler_indices_padded * + len(sampler_indices_padded))) long_lora_indices = None long_lora_indices_len: Optional[int] = None if long_lora_context: @@ -424,20 +424,20 @@ def __init__( self.long_lora_context: Optional[LongContextLoRAContext] = None self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) self.sampler_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) self.embeddings_indices = torch.empty(2, self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) self.long_lora_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} @@ -465,11 +465,25 @@ def __init__( @property def capacity(self) -> int: - return self.lora_config.max_cpu_loras + if is_hpu(): + # HPU handles no LoRA requests using zero valued A and B tensors. + # These zero valued tensors are appended at the end of A and B, + # making total number of loras to be lora_config.max_cpu_loras + 1. + # This demands the total number of max_cpu_loras to be + # lora_config.max_cpu_loras + 1 + return self.lora_config.max_cpu_loras + 1 + else: + return self.lora_config.max_cpu_loras @property def lora_slots(self) -> int: - return self.lora_config.max_loras + if is_hpu(): + # HPU handles no LoRA requests using zero valued A and B tensors. + # These zero valued tensors are appended at the end of A and B, + # making total number of loras to be lora_config.max_cpu_loras + 1. 
+ return self.lora_config.max_loras + 1 + else: + return self.lora_config.max_loras @property def adapter_slots(self) -> int: diff --git a/vllm/utils.py b/vllm/utils.py index fe84253feb172..fa6e132dd3522 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -970,6 +970,12 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) +def get_device() -> str: + if is_hpu(): + return "hpu" + return "cuda" + + def error_on_invalid_device_count_status(): cache_entries = 0 with contextlib.suppress(Exception): diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index e52b61539b540..d129bb5cbc0ca 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -435,6 +435,23 @@ def load_model(self) -> None: f"took {m_getmodel.get_summary_string()}") logger.info(msg) + if self.lora_config: + assert hasattr(self.model, "supported_lora_modules" + ) and self.model.supported_lora_modules, ( + "Model does not support LoRA") + assert hasattr(self.model, "embedding_modules" + ), "Model does not have embedding_modules" + assert hasattr( + self.model, "embedding_padding_modules" + ), "Model does not have embedding_padding_modules" + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, self.lora_config, self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules) + self.model = self.lora_manager.create_lora_manager(self.model) + if self.model_config.quantization == 'inc': logger.info("Preparing model with INC..") with HabanaMemoryProfiler() as m_inc: @@ -467,35 +484,26 @@ def load_model(self) -> None: msg = f"Loading model weights took in total {m.get_summary_string()}" logger.info(msg) - if self.lora_config: - assert hasattr(self.model, "supported_lora_modules" - ) and self.model.supported_lora_modules, ( - "Model does not support LoRA") - assert hasattr( - self.model, - "embedding_modules"), "Model does not have embedding_modules" - assert hasattr(self.model, "embedding_padding_modules" - ), "Model does not have embedding_padding_modules" - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.vocab_size, - self.lora_config, self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) - self.model = self.lora_manager.create_lora_manager(self.model) - def _use_graphs(self, batch_size, seq_len, is_prompt): if self.enforce_eager: return False return (batch_size, seq_len, is_prompt) in self.graphed_buckets + def _is_valid_bucket(self, bucket): + return bucket[0] * bucket[1] <= self.max_num_batched_tokens + def _setup_buckets(self) -> None: + max_bucket_cfg = 64 + if self.lora_config and \ + max_bucket_cfg > self.max_num_batched_tokens // self.block_size: + max_bucket_cfg = self.max_num_batched_tokens // self.block_size self.prompt_bs_bucket_cfg = read_bucket_settings('prompt', 'bs', min=1, step=32, max=min( self.max_num_seqs, - 64)) + max_bucket_cfg)) self.decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', min=1, @@ -520,6 +528,12 @@ def _setup_buckets(self) -> None: self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg) + if self.lora_config: + self.prompt_buckets[:] = [ + bucket for bucket in self.prompt_buckets + if self._is_valid_bucket(bucket) + ] + msg = (f"Generated {len(self.prompt_buckets)} " f"prompt buckets: 
{list(sorted(self.prompt_buckets))}") logger.info(msg) @@ -530,6 +544,11 @@ def _setup_buckets(self) -> None: logger.info(msg) self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg, self.decode_seq_bucket_cfg) + if self.lora_config: + self.decode_buckets[:] = [ + bucket for bucket in self.decode_buckets + if self._is_valid_bucket(bucket) + ] msg = (f"Generated {len(self.decode_buckets)} decode buckets: " f"{list(sorted(self.decode_buckets))}") logger.info(msg) @@ -606,16 +625,6 @@ def _prepare_prompt( # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.append(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.append( - [lora_id] * - (seq_len - context_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( @@ -674,6 +683,20 @@ def _prepare_prompt( max_prompt_len = max( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), self.block_size) + + for seq_group_metadata, context_len in zip(seq_group_metadata_list, + context_lens): + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (max_prompt_len - context_len) + lora_prompt_mapping.extend( + [lora_id] * + (max_prompt_len - context_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + input_tokens = make_tensor_with_pad(input_tokens, max_len=max_prompt_len, pad=0, @@ -1027,7 +1050,11 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: ]) return attention_metadata - def create_dummy_seq_group_metadata(self, group_id, seq_len, is_prompt): + def create_dummy_seq_group_metadata(self, + group_id, + seq_len, + is_prompt, + lora_request=None): sampling_params = SamplingParams(temperature=0) num_blocks = math.ceil(seq_len / self.block_size) if is_prompt: @@ -1042,34 +1069,78 @@ def create_dummy_seq_group_metadata(self, group_id, seq_len, is_prompt): output_token_ids = [1] * output_len seq_data = SequenceData(prompt_token_ids) seq_data.output_token_ids = output_token_ids - return SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=(output_len == 0), - seq_data={group_id: seq_data}, - sampling_params=sampling_params, - block_tables=block_tables, - ) + return SequenceGroupMetadata(request_id=str(group_id), + is_prompt=(output_len == 0), + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + lora_request=lora_request) def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers max_batch_size = self.prompt_bs_bucket_cfg[-1] max_seq_len = self.prompt_seq_bucket_cfg[-1] - - self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches) - - def warmup_scenario(self, batch_size, seq_len, is_prompt, - kv_caches) -> None: + if self.lora_config: + max_seq_len = self.max_num_batched_tokens // max_batch_size + + self.warmup_scenario(max_batch_size, + max_seq_len, + True, + kv_caches, + is_profile_run=True) + + def warmup_scenario(self, + batch_size, + seq_len, + is_prompt, + kv_caches, + is_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" 
f"bs{batch_size}_" f"seq{seq_len}_" f"graphs{'T' if use_graphs else 'F'}") + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config and is_profile_run: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] self.profiler.start('internal', scenario_name) times = 3 if use_graphs else 1 + if self.lora_config and not is_profile_run: + lora_mapping = LoRAMapping( + [0] * batch_size * seq_len, + [0] * batch_size * seq_len, + ) + self.set_active_loras(set(), lora_mapping) seqs = [ - self.create_dummy_seq_group_metadata(i, seq_len, is_prompt) + self.create_dummy_seq_group_metadata( + i, + seq_len, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) for i in range(batch_size) ] torch.hpu.synchronize() @@ -1080,6 +1151,37 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt, self.profiler.end() gc.collect() + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + def log_warmup(self, phase, i, max_i, batch_size, seq_len): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) @@ -1403,9 +1505,11 @@ def execute_model( raise ValueError( "num_steps > 1 is not supported in HabanaModelRunner") - # NOTE(kzawora): Need to restore this after adding LoRA - # if self.lora_config: - # self.set_active_loras(lora_requests, lora_mapping) + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata @@ -1452,6 +1556,19 @@ def 
execute_model( selected_token_indices=sampling_metadata.selected_token_indices ) + if self.lora_config: + from vllm.lora.layers import VocabParallelEmbeddingWithLoRA + property = vars(self.model.model) + model = list(property['_modules'].values())[0] + property = vars(model) + modules = list(property['_modules'].values()) + for module in modules: + if isinstance(module, VocabParallelEmbeddingWithLoRA): + for i in range(0, 4): + module.indices_len[ + i] = sampling_metadata.selected_token_indices.numel( + ) + # Compute the logits. with self.profiler.record_event( 'internal', ('compute_logits_' diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index 87122c03d3c8f..9d083915041fe 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -174,9 +174,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_hpu_blocks = max(num_hpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - # NOTE(kzawora): Restore this once LoRA support is added - # if self.model_runner.lora_manager: - # self.model_runner.remove_all_loras() + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() gc.collect() return num_hpu_blocks, num_cpu_blocks @@ -279,29 +278,33 @@ def execute_worker(self, worker_input: WorkerInput) -> None: self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + return self.model_runner.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") - - def list_loras(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + return self.model_runner.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def shutdown_inc(self): self.model_runner.shutdown_inc()
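Reviewer note, not part of the patch: the no-loop BGMV formulation used by dispatch_bgmv_linear in vllm/hpu/ops.py can be sanity-checked against a naive per-sample loop with the standalone CPU sketch below. The tensor layout, the extra zero-valued LoRA slot, and the wrap-around of -1 indices mirror the docstring in the diff above; all dimensions, values, and variable names here are illustrative only.

import torch

# Illustrative sizes only (not taken from the patch).
batch_size, hidden_dim, out_dim, lora_rank, num_loras = 4, 16, 32, 8, 3

# Stacked LoRA A/B weights; the extra last slot is zero-valued and serves
# the "no LoRA" case, matching the layout described in the docstrings.
wa_t_all = torch.randn(num_loras + 1, 1, lora_rank, hidden_dim)
wb_t_all = torch.randn(num_loras + 1, 1, out_dim, lora_rank)
wa_t_all[-1].zero_()
wb_t_all[-1].zero_()

x = torch.randn(batch_size, hidden_dim)
y = torch.zeros(batch_size, out_dim)
indices = torch.tensor([0, 2, -1, 1])  # -1 marks a request without a LoRA
scale = 1.0

# Batched path (same algebra as dispatch_bgmv_linear): wrap negative indices
# onto the zero slot, gather per-sample A/B, then two batched matmuls.
idx = indices % (num_loras + 1)
wa = torch.index_select(wa_t_all, 0, idx)[:, 0, :, :].transpose(-1, -2)
wb = torch.index_select(wb_t_all, 0, idx)[:, 0, :, :].transpose(-1, -2)
y += (x.unsqueeze(1) @ wa @ wb).squeeze(1) * scale

# Reference: explicit per-sample loop, which the batched form avoids so that
# HPU graph compilation sees a single static computation.
y_ref = torch.zeros_like(y)
for i, lora_id in enumerate(indices.tolist()):
    if lora_id >= 0:
        a = wa_t_all[lora_id, 0]  # (lora_rank, hidden_dim)
        b = wb_t_all[lora_id, 0]  # (out_dim, lora_rank)
        y_ref[i] = scale * (b @ (a @ x[i]))

assert torch.allclose(y, y_ref, atol=1e-4)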