diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index a3e3a97a..43ab704d 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -104,8 +104,9 @@ jobs:
           pytest -sv --durations=0 tests/e2e/singlecard/test_cpu_offloading.py
           # xgrammar has parameter mismatching bug, please follows: https://github.com/vllm-project/vllm-ascend/issues/5524
           # pytest -sv --durations=0 tests/e2e/singlecard/test_guided_decoding.py
-          # torch 2.8 doesn't work with lora, fix me
           pytest -sv --durations=0 tests/e2e/singlecard/test_ilama_lora.py
+          pytest -sv --durations=0 tests/e2e/singlecard/test_llama32_lora.py
+          pytest -sv --durations=0 tests/e2e/singlecard/test_qwen3_multi_loras.py
           pytest -sv --durations=0 tests/e2e/singlecard/test_models.py
           pytest -sv --durations=0 tests/e2e/singlecard/test_multistream_overlap_shared_expert.py
           pytest -sv --durations=0 tests/e2e/singlecard/test_profile_execute_duration.py
@@ -215,7 +216,6 @@ jobs:
           pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_expert_parallel.py
           pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_external_launcher.py
           pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_full_graph_mode.py
-          # torch 2.8 doesn't work with lora, fix me
           pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py
index 1aa6445b..6f3c4fbe 100644
--- a/tests/e2e/conftest.py
+++ b/tests/e2e/conftest.py
@@ -756,6 +756,11 @@ def ilama_lora_files():
     return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider")
 
 
+@pytest.fixture(scope="session")
+def llama32_lora_files():
+    return snapshot_download(repo_id="vllm-ascend/llama32-3b-text2sql-spider")
+
+
 def qwen_prompt(questions: list[str]) -> list[str]:
     placeholder = "<|image_pad|>"
     return [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
diff --git a/tests/e2e/singlecard/test_llama32_lora.py b/tests/e2e/singlecard/test_llama32_lora.py
new file mode 100644
index 00000000..3c71a9ad
--- /dev/null
+++ b/tests/e2e/singlecard/test_llama32_lora.py
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import vllm
+import vllm.config
+from modelscope import snapshot_download  # type: ignore
+from vllm.lora.request import LoRARequest
+
+from tests.e2e.conftest import VllmRunner
+from vllm_ascend.utils import enable_custom_op
+
+enable_custom_op()
+
+PROMPT_TEMPLATE = """<|eot_id|><|start_header_id|>user<|end_header_id|>
+I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
+"
+##Instruction:
+candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
+Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
+The People_ID of candidate is the foreign key of People_ID of people.
+
+###Input:
+{context}
+###Response:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+"""  # noqa: E501
+
+EXPECTED_LORA_OUTPUT = [
+    "SELECT count(*) FROM candidate",
+    "SELECT count(*) FROM candidate",
+    "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
+    "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
+]
+
+MODEL_PATH = "vllm-ascend/Llama-3.2-3B-Instruct"
+
+
+def do_sample(
+    llm: vllm.LLM,
+    lora_path: str,
+    lora_id: int,
+    tensorizer_config_dict: dict | None = None,
+) -> list[str]:
+    prompts = [
+        PROMPT_TEMPLATE.format(context="How many candidates are there?"),
+        PROMPT_TEMPLATE.format(context="Count the number of candidates."),
+        PROMPT_TEMPLATE.format(
+            context=
+            "Which poll resource provided the most number of candidate information?"  # noqa: E501
+        ),
+        PROMPT_TEMPLATE.format(
+            context=
+            "Return the poll resource associated with the most candidates."),
+    ]
+
+    sampling_params = vllm.SamplingParams(temperature=0,
+                                          max_tokens=64,
+                                          stop=["<|im_end|>"])
+    if tensorizer_config_dict is not None:
+        outputs = llm.generate(
+            prompts,
+            sampling_params,
+            lora_request=LoRARequest(
+                str(lora_id),
+                lora_id,
+                lora_path,
+                tensorizer_config_dict=tensorizer_config_dict,
+            ) if lora_id else None,
+        )
+    else:
+        outputs = llm.generate(
+            prompts,
+            sampling_params,
+            lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+            if lora_id else None,
+        )
+
+    generated_texts: list[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+def generate_and_test(llm,
+                      llama32_lora_files,
+                      tensorizer_config_dict: dict | None = None):
+    print("lora adapter created")
+    print("lora 1")
+    assert (do_sample(
+        llm,
+        llama32_lora_files,
+        tensorizer_config_dict=tensorizer_config_dict,
+        lora_id=1,
+    ) == EXPECTED_LORA_OUTPUT)
+
+    print("lora 2")
+    assert (do_sample(
+        llm,
+        llama32_lora_files,
+        tensorizer_config_dict=tensorizer_config_dict,
+        lora_id=2,
+    ) == EXPECTED_LORA_OUTPUT)
+
+    print("removing lora")
+
+
+def test_llama_lora(llama32_lora_files):
+    vllm_model = VllmRunner(
+        snapshot_download(MODEL_PATH),
+        enable_lora=True,
+        # also test odd max_num_seqs
+        max_num_seqs=7,
+        max_model_len=1024,
+        max_loras=4,
+    )
+    llm = vllm_model.model
+    generate_and_test(llm, llama32_lora_files)
diff --git a/tests/e2e/singlecard/test_qwen3_multi_loras.py b/tests/e2e/singlecard/test_qwen3_multi_loras.py
new file mode 100644
index 00000000..733b6cf9
--- /dev/null
+++ b/tests/e2e/singlecard/test_qwen3_multi_loras.py
@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from modelscope import snapshot_download  # type: ignore
+from vllm import SamplingParams
+from vllm.lora.request import LoRARequest
+
+from tests.e2e.conftest import VllmRunner
+from vllm_ascend.utils import enable_custom_op
+
+enable_custom_op()
+
+MODEL_PATH = "vllm-ascend/Qwen3-0.6B"
+LORA_NAME_PATH_MAP = {
+    "Alice": "vllm-ascend/self_cognition_Alice",
+    "Bob": "vllm-ascend/self_cognition_Bob",
+    "Cat": "vllm-ascend/self_cognition_Bob",  # same as Bob
+}
+
+LORA_RANK = 8
+
+LORA_TEST_PROMPTS = ["What is GitHub?", "Hi, tell me about you"]
+LORA_TEST_EXPECTED = [
+    "GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.",  # noqa: E501
+    "I am Alice, an AI assistant developed by GitHub/Charent.",  # noqa: E501
+]
+
+
+def format_chatml_messages(prompt: str):
+    return [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant."
+        },
+        {
+            "role": "user",
+            "content": prompt
+        },
+    ]
+
+
+def test_multi_loras_with_tp_sync():
+
+    lora_name_id_map = {}
+    increase_lora_id = 0
+
+    def make_add_lora_request(name: str, path: str):
+        nonlocal increase_lora_id
+        increase_lora_id += 1
+        lora_name_id_map[name] = increase_lora_id
+
+        return LoRARequest(
+            lora_name=name,
+            lora_int_id=increase_lora_id,
+            lora_path=snapshot_download(path),
+        )
+
+    vllm_model = VllmRunner(
+        snapshot_download(MODEL_PATH),
+        enable_lora=True,
+        # dtype="half",
+        max_loras=2,  # ensure max_loras < max_cpu_loras
+        max_lora_rank=LORA_RANK,
+        max_model_len=512,
+        gpu_memory_utilization=0.9,
+        enforce_eager=True,
+        # tensor_parallel_size=2,  # ensure tp >= 2
+        max_cpu_loras=4,  # ensure max_cpu_loras >= 2
+    )
+    llm = vllm_model.model
+
+    def run_check_lora(fn, args, expected: list):
+        fn(args)
+        assert set(llm.llm_engine.list_loras()) == set(expected)
+
+    # simulate add loras with CLI args
+    # likes: `--lora-modules Alice=/path/to/Alice Bob=/path/to/Bob`
+    run_check_lora(
+        llm.llm_engine.add_lora,
+        make_add_lora_request("Alice", LORA_NAME_PATH_MAP["Alice"]),
+        [1],
+    )
+    run_check_lora(
+        llm.llm_engine.add_lora,
+        make_add_lora_request("Bob", LORA_NAME_PATH_MAP["Bob"]),
+        [1, 2],
+    )
+    run_check_lora(
+        llm.llm_engine.add_lora,
+        make_add_lora_request("Cat", LORA_NAME_PATH_MAP["Cat"]),
+        [1, 2, 3],
+    )
+
+    # set temperature = 0 for greedy search
+    sampling_params = SamplingParams(temperature=0, max_tokens=64)
+
+    def call_llm_get_outputs(prompt: str, lora_name: str):
+        lora_request = LoRARequest(
+            lora_name=lora_name,
+            lora_int_id=lora_name_id_map[lora_name],
+            lora_path=LORA_NAME_PATH_MAP[lora_name],
+        )
+        messages = format_chatml_messages(prompt)
+        outputs = llm.chat(
+            [messages],
+            sampling_params,
+            chat_template_kwargs={
+                "enable_thinking": False
+            },  # for those loras, ensure enable_thinking=False
+            lora_request=lora_request,
+            use_tqdm=False,
+        )
+        output_text = outputs[0].outputs[0].text
+        return output_text
+
+    def reload_lora(name: str):
+        """
+        reload a lora to simulate the case:
+        setting `VLLM_ALLOW_RUNTIME_LORA_UPDATING=true`
+        for dynamic lora loading and unloading
+        """
+        remove_lora_response = llm.llm_engine.remove_lora(
+            lora_id=lora_name_id_map[name])
+
+        add_lora_response = llm.llm_engine.add_lora(
+            make_add_lora_request(name, LORA_NAME_PATH_MAP[name]))
+
+        print(f"{remove_lora_response=}, {add_lora_response=}")
+
+    def check_outputs(outputs: str, expected: str, prompt: str):
+        print(f"{prompt=}.\n{expected=}\n{outputs=}")
+        print("\n----------------------------\n")
+        assert outputs == expected
+
+    for prompt, expected_output in zip(LORA_TEST_PROMPTS, LORA_TEST_EXPECTED):
+
+        output_text = call_llm_get_outputs(prompt, "Alice")
+        check_outputs(output_text, expected_output, prompt)
+
+        # call Bob, ignore what it is output
+        call_llm_get_outputs(prompt, "Bob")
+        print("After call Bob:")
+
+        # call Alice
+        output_text = call_llm_get_outputs(prompt, "Alice")
+        check_outputs(output_text, expected_output, prompt)
+
+        # reload Bob Lora
+        reload_lora("Bob")
+        print("After reload Bob:")
+
+        # call Alice
+        output_text = call_llm_get_outputs(prompt, "Alice")
+        check_outputs(output_text, expected_output, prompt)
+
+        # reload Alice Lora
+        reload_lora("Alice")
+        print("After reload Alice:")
+
+        output_text = call_llm_get_outputs(prompt, "Alice")
+        check_outputs(output_text, expected_output, prompt)
\ No newline at end of file