Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)

This commit is contained in:
Lianmin Zheng
2024-10-11 05:03:20 -07:00
committed by GitHub
parent bbd72bfc86
commit aba9eae4c6
2 changed files with 17 additions and 8 deletions

View File

@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
return reqs
@torch.inference_mode()
def extend(reqs, model_runner):
batch = ScheduleBatch.init_new(
reqs=reqs,
@@ -235,6 +236,7 @@ def extend(reqs, model_runner):
return next_token_ids, logits_output.next_token_logits, batch
@torch.inference_mode()
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
model_worker_batch = batch.get_model_worker_batch()
@@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner):
return next_token_ids, logits_output.next_token_logits
@torch.inference_mode()
def correctness_test(
server_args,
port_args,
@@ -287,7 +288,6 @@ def correctness_test(
rank_print(tokenizer.decode(output_ids[i]), "\n")
@torch.inference_mode()
def latency_test_run_once(
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
):