Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)
This commit is contained in:
@@ -42,13 +42,13 @@ class ModelCase:
|
||||
rouge_l_tolerance: float = 1
|
||||
|
||||
|
||||
# Popular models that run on CI
|
||||
# Popular models that run on the CI
|
||||
CI_MODELS = [
|
||||
ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
|
||||
ModelCase("google/gemma-2-2b"),
|
||||
]
|
||||
|
||||
# All other models
|
||||
# All other models that do not run on the CI
|
||||
ALL_OTHER_MODELS = [
|
||||
ModelCase("Qwen/Qwen2-1.5B"),
|
||||
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||
@@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16]
|
||||
|
||||
|
||||
class TestGenerationModels(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
def assert_close_logits_and_output_strs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -140,16 +144,21 @@ class TestGenerationModels(unittest.TestCase):
|
||||
return
|
||||
|
||||
for model_case in ALL_OTHER_MODELS:
|
||||
# Only run a specified model
|
||||
if (
|
||||
"ONLY_RUN" in os.environ
|
||||
and os.environ["ONLY_RUN"] != model_case.model_path
|
||||
):
|
||||
continue
|
||||
self.assert_close_logits_and_output_strs(
|
||||
DEFAULT_PROMPTS, model_case, torch.float16
|
||||
)
|
||||
|
||||
# Skip long prompts for models that do not have a long context
|
||||
prompts = DEFAULT_PROMPTS
|
||||
if model_case.model_path in ["HuggingFaceTB/SmolLM-135M-Instruct"]:
|
||||
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
|
||||
# Assert the logits and output strs are close
|
||||
self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn")
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user