Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)
This commit is contained in:
@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
|||||||
return reqs
|
return reqs
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def extend(reqs, model_runner):
|
def extend(reqs, model_runner):
|
||||||
batch = ScheduleBatch.init_new(
|
batch = ScheduleBatch.init_new(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
@@ -235,6 +236,7 @@ def extend(reqs, model_runner):
|
|||||||
return next_token_ids, logits_output.next_token_logits, batch
|
return next_token_ids, logits_output.next_token_logits, batch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def decode(input_token_ids, batch, model_runner):
|
def decode(input_token_ids, batch, model_runner):
|
||||||
batch.prepare_for_decode(input_token_ids)
|
batch.prepare_for_decode(input_token_ids)
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
@@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner):
|
|||||||
return next_token_ids, logits_output.next_token_logits
|
return next_token_ids, logits_output.next_token_logits
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def correctness_test(
|
def correctness_test(
|
||||||
server_args,
|
server_args,
|
||||||
port_args,
|
port_args,
|
||||||
@@ -287,7 +288,6 @@ def correctness_test(
|
|||||||
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def latency_test_run_once(
|
def latency_test_run_once(
|
||||||
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
|
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -42,13 +42,13 @@ class ModelCase:
|
|||||||
rouge_l_tolerance: float = 1
|
rouge_l_tolerance: float = 1
|
||||||
|
|
||||||
|
|
||||||
# Popular models that run on CI
|
# Popular models that run on the CI
|
||||||
CI_MODELS = [
|
CI_MODELS = [
|
||||||
ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
|
ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
|
||||||
ModelCase("google/gemma-2-2b"),
|
ModelCase("google/gemma-2-2b"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# All other models
|
# All other models that do not run on the CI
|
||||||
ALL_OTHER_MODELS = [
|
ALL_OTHER_MODELS = [
|
||||||
ModelCase("Qwen/Qwen2-1.5B"),
|
ModelCase("Qwen/Qwen2-1.5B"),
|
||||||
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||||
@@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16]
|
|||||||
|
|
||||||
|
|
||||||
class TestGenerationModels(unittest.TestCase):
|
class TestGenerationModels(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
mp.set_start_method("spawn")
|
||||||
|
|
||||||
def assert_close_logits_and_output_strs(
|
def assert_close_logits_and_output_strs(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
@@ -140,16 +144,21 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for model_case in ALL_OTHER_MODELS:
|
for model_case in ALL_OTHER_MODELS:
|
||||||
|
# Only run a specified model
|
||||||
if (
|
if (
|
||||||
"ONLY_RUN" in os.environ
|
"ONLY_RUN" in os.environ
|
||||||
and os.environ["ONLY_RUN"] != model_case.model_path
|
and os.environ["ONLY_RUN"] != model_case.model_path
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
self.assert_close_logits_and_output_strs(
|
|
||||||
DEFAULT_PROMPTS, model_case, torch.float16
|
# Skip long prompts for models that do not have a long context
|
||||||
)
|
prompts = DEFAULT_PROMPTS
|
||||||
|
if model_case.model_path in ["HuggingFaceTB/SmolLM-135M-Instruct"]:
|
||||||
|
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||||
|
|
||||||
|
# Assert the logits and output strs are close
|
||||||
|
self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mp.set_start_method("spawn")
|
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user