Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)
This commit is contained in:
@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
||||
return reqs
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def extend(reqs, model_runner):
|
||||
batch = ScheduleBatch.init_new(
|
||||
reqs=reqs,
|
||||
@@ -235,6 +236,7 @@ def extend(reqs, model_runner):
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
@@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner):
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def correctness_test(
|
||||
server_args,
|
||||
port_args,
|
||||
@@ -287,7 +288,6 @@ def correctness_test(
|
||||
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def latency_test_run_once(
|
||||
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
|
||||
):
|
||||
|
||||
@@ -42,13 +42,13 @@ class ModelCase:
|
||||
rouge_l_tolerance: float = 1
|
||||
|
||||
|
||||
# Popular models that run on CI
|
||||
# Popular models that run on the CI
|
||||
CI_MODELS = [
|
||||
ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
|
||||
ModelCase("google/gemma-2-2b"),
|
||||
]
|
||||
|
||||
# All other models
|
||||
# All other models that do not run on the CI
|
||||
ALL_OTHER_MODELS = [
|
||||
ModelCase("Qwen/Qwen2-1.5B"),
|
||||
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||
@@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16]
|
||||
|
||||
|
||||
class TestGenerationModels(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
def assert_close_logits_and_output_strs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -140,16 +144,21 @@ class TestGenerationModels(unittest.TestCase):
|
||||
return
|
||||
|
||||
for model_case in ALL_OTHER_MODELS:
|
||||
# Only run a specified model
|
||||
if (
|
||||
"ONLY_RUN" in os.environ
|
||||
and os.environ["ONLY_RUN"] != model_case.model_path
|
||||
):
|
||||
continue
|
||||
self.assert_close_logits_and_output_strs(
|
||||
DEFAULT_PROMPTS, model_case, torch.float16
|
||||
)
|
||||
|
||||
# Skip long prompts for models that does not have a long context
|
||||
prompts = DEFAULT_PROMPTS
|
||||
if model_case.model_path in ["HuggingFaceTB/SmolLM-135M-Instruct"]:
|
||||
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
|
||||
# Assert the logits and output strs are close
|
||||
self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn")
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user