[Test][LoRA] Add e2e test for base model inference (#6624)

### What this PR does / why we need it?

This PR adds an end-to-end test case that verifies the correctness of base
model inference when LoRA is enabled. It guards against regressions of a
previously fixed issue where base-model requests misbehaved on a
LoRA-enabled engine. The new test case calls `do_sample` with `lora_id=0`
to target the base model and asserts the outputs against the expected SQL
queries.
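
For context, here is a minimal sketch of the gating inside the `do_sample` helper. This is an assumption based on vLLM's common LoRA test pattern, not code from this diff (the prompts, decoding settings, and simplified signature are illustrative placeholders): a falsy `lora_id` skips attaching a `LoRARequest`, so the prompts are served by the base model weights.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


def do_sample(llm: LLM, lora_path: str, lora_id: int) -> list[str]:
    # Sketch (assumed, not from this diff): prompt and decoding settings
    # are illustrative placeholders.
    prompts = ["Give the SQL to count all rows in the candidate table."]
    sampling_params = SamplingParams(temperature=0, max_tokens=256)
    outputs = llm.generate(
        prompts,
        sampling_params,
        # lora_id == 0 is falsy, so no LoRARequest is attached and the
        # request runs against the base model instead of an adapter.
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None,
    )
    return [output.outputs[0].text.strip() for output in outputs]
```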

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed with the new test case. The test can be run with:
```bash
pytest -sv tests/e2e/singlecard/test_llama32_lora.py
```

Signed-off-by: paulyu12 <507435917@qq.com>

```diff
@@ -29,6 +29,14 @@ EXPECTED_LORA_OUTPUT = [
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
 ]
+EXPECTED_BASE_MODEL_OUTPUT = [
+    "SELECT COUNT(*) FROM candidate",
+    "`SELECT COUNT(*) FROM candidate;`",
+    "SELECT Poll_Source FROM candidate GROUP BY Poll_Source ORDER BY COUNT(*) DESC LIMIT 1;",
+    "SELECT * FROM candidate ORDER BY Candidate_ID DESC LIMIT 1",
+]
 # For hk region, we need to use the model from hf to avoid the network issue
 MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"
@@ -102,6 +110,14 @@ def generate_and_test(llm,
         lora_id=2,
     ) == EXPECTED_LORA_OUTPUT)
+    print("base model")
+    assert (do_sample(
+        llm,
+        llama32_lora_files,
+        tensorizer_config_dict=tensorizer_config_dict,
+        lora_id=0,
+    ) == EXPECTED_BASE_MODEL_OUTPUT)
     print("removing lora")
```