Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)

This commit is contained in:
Lianmin Zheng
2024-10-11 05:03:20 -07:00
committed by GitHub
parent bbd72bfc86
commit aba9eae4c6
2 changed files with 17 additions and 8 deletions

View File

@@ -42,13 +42,13 @@ class ModelCase:
rouge_l_tolerance: float = 1
# Popular models that run on CI
# Popular models that run on the CI
CI_MODELS = [
ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
ModelCase("google/gemma-2-2b"),
]
# All other models
# All other models that do not run on the CI
ALL_OTHER_MODELS = [
ModelCase("Qwen/Qwen2-1.5B"),
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
@@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16]
class TestGenerationModels(unittest.TestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn")
def assert_close_logits_and_output_strs(
self,
prompts: List[str],
@@ -140,16 +144,21 @@ class TestGenerationModels(unittest.TestCase):
return
for model_case in ALL_OTHER_MODELS:
# Only run a specified model
if (
"ONLY_RUN" in os.environ
and os.environ["ONLY_RUN"] != model_case.model_path
):
continue
self.assert_close_logits_and_output_strs(
DEFAULT_PROMPTS, model_case, torch.float16
)
# Skip long prompts for models that do not have a long context
prompts = DEFAULT_PROMPTS
if model_case.model_path in ["HuggingFaceTB/SmolLM-135M-Instruct"]:
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
# Assert the logits and output strs are close
self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)
if __name__ == "__main__":
mp.set_start_method("spawn")
unittest.main()