[Test][LoRA] Add e2e test for base model inference (#6624)
### What this PR does / why we need it? This PR adds an end-to-end test case to verify the correctness of base model inference when LoRA is enabled. This is to ensure that after a LoRA base-model request issue was fixed, the functionality remains correct and does not regress. The new test case calls `do_sample` with `lora_id=0` to target the base model and asserts the output against expected SQL queries. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with the new test case. The test can be run with: ```bash
pytest -sv tests/e2e/singlecard/test_llama32_lora.py
``` Signed-off-by: paulyu12 <507435917@qq.com>
This commit is contained in:
@@ -29,6 +29,14 @@ EXPECTED_LORA_OUTPUT = [
    "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
    "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
]

EXPECTED_BASE_MODEL_OUTPUT = [
    "SELECT COUNT(*) FROM candidate",
    "`SELECT COUNT(*) FROM candidate;`",
    "SELECT Poll_Source FROM candidate GROUP BY Poll_Source ORDER BY COUNT(*) DESC LIMIT 1;",
    "SELECT * FROM candidate ORDER BY Candidate_ID DESC LIMIT 1",
]

# For hk region, we need to use the model from hf to avoid the network issue
MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"
||||
@@ -102,6 +110,14 @@ def generate_and_test(llm,
        lora_id=2,
    ) == EXPECTED_LORA_OUTPUT)

    print("base model")
    assert (do_sample(
        llm,
        llama32_lora_files,
        tensorizer_config_dict=tensorizer_config_dict,
        lora_id=0,
    ) == EXPECTED_BASE_MODEL_OUTPUT)

    print("removing lora")
||||
Reference in New Issue
Block a user