diff --git a/tests/e2e/singlecard/test_llama32_lora.py b/tests/e2e/singlecard/test_llama32_lora.py index 782d67df..6314014b 100644 --- a/tests/e2e/singlecard/test_llama32_lora.py +++ b/tests/e2e/singlecard/test_llama32_lora.py @@ -29,6 +29,14 @@ EXPECTED_LORA_OUTPUT = [ "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 ] + +EXPECTED_BASE_MODEL_OUTPUT = [ + "SELECT COUNT(*) FROM candidate", + "`SELECT COUNT(*) FROM candidate;`", + "SELECT Poll_Source FROM candidate GROUP BY Poll_Source ORDER BY COUNT(*) DESC LIMIT 1;", + "SELECT * FROM candidate ORDER BY Candidate_ID DESC LIMIT 1", +] + # For hk region, we need to use the model from hf to avoid the network issue MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct" @@ -102,6 +110,14 @@ def generate_and_test(llm, lora_id=2, ) == EXPECTED_LORA_OUTPUT) + print("base model") + assert (do_sample( + llm, + llama32_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=0, + ) == EXPECTED_BASE_MODEL_OUTPUT) + print("removing lora")