From 8d44ddacb026ddffea993816802aa98da981992f Mon Sep 17 00:00:00 2001 From: yupeng <507435917@qq.com> Date: Mon, 9 Feb 2026 21:06:49 +0800 Subject: [PATCH] [Test][LoRA] Add e2e test for base model inference (#6624) ### What this PR does / why we need it? This PR adds an end-to-end test case to verify the correctness of base model inference when LoRA is enabled. This is to ensure that after a LoRA base model request issue was fixed, the functionality remains correct and does not regress. The new test case calls `do_sample` with `lora_id=0` to target the base model and asserts the output against expected SQL queries. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with the new test case. The test can be run with: ```bash pytest -sv tests/e2e/singlecard/test_llama32_lora.py ``` Signed-off-by: paulyu12 <507435917@qq.com> --- tests/e2e/singlecard/test_llama32_lora.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/e2e/singlecard/test_llama32_lora.py b/tests/e2e/singlecard/test_llama32_lora.py index 782d67df..6314014b 100644 --- a/tests/e2e/singlecard/test_llama32_lora.py +++ b/tests/e2e/singlecard/test_llama32_lora.py @@ -29,6 +29,14 @@ EXPECTED_LORA_OUTPUT = [ "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 ] + +EXPECTED_BASE_MODEL_OUTPUT = [ + "SELECT COUNT(*) FROM candidate", + "`SELECT COUNT(*) FROM candidate;`", + "SELECT Poll_Source FROM candidate GROUP BY Poll_Source ORDER BY COUNT(*) DESC LIMIT 1;", + "SELECT * FROM candidate ORDER BY Candidate_ID DESC LIMIT 1", +] + # For hk region, we need to use the model from hf to avoid the network issue MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct" @@ -102,6 +110,14 @@ def generate_and_test(llm, lora_id=2, ) == EXPECTED_LORA_OUTPUT) + print("base model") + assert (do_sample( + llm, + 
llama32_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=0, + ) == EXPECTED_BASE_MODEL_OUTPUT) + print("removing lora")