[Test][LoRA] Add e2e test for base model inference (#6624)
### What this PR does / why we need it?

This PR adds an end-to-end test case to verify the correctness of base model inference when LoRA is enabled. This ensures that, after the LoRA base-model request issue was fixed, the functionality remains correct and does not regress. The new test case calls `do_sample` with `lora_id=0` to target the base model and asserts the output against the expected SQL queries.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed with the new test case. The test can be run with:

```bash
pytest -sv tests/e2e/singlecard/test_llama32_lora.py
```

Signed-off-by: paulyu12 <507435917@qq.com>
This commit is contained in:
@@ -29,6 +29,14 @@ EXPECTED_LORA_OUTPUT = [
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
 ]
# Expected completions from the base model (lora_id=0), i.e. generation with
# no LoRA adapter applied. Used to verify that base-model inference stays
# correct and does not regress when LoRA is enabled on the engine.
EXPECTED_BASE_MODEL_OUTPUT = [
    "SELECT COUNT(*) FROM candidate",
    "`SELECT COUNT(*) FROM candidate;`",
    "SELECT Poll_Source FROM candidate GROUP BY Poll_Source ORDER BY COUNT(*) DESC LIMIT 1;",  # noqa: E501
    "SELECT * FROM candidate ORDER BY Candidate_ID DESC LIMIT 1",
]

 # For hk region, we need to use the model from hf to avoid the network issue
 MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"
@@ -102,6 +110,14 @@ def generate_and_test(llm,
         lora_id=2,
     ) == EXPECTED_LORA_OUTPUT)

+    print("base model")
+    assert (do_sample(
+        llm,
+        llama32_lora_files,
+        tensorizer_config_dict=tensorizer_config_dict,
+        lora_id=0,
+    ) == EXPECTED_BASE_MODEL_OUTPUT)
+
     print("removing lora")
|||||||
Reference in New Issue
Block a user