Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)
This commit is contained in:
@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
||||
return reqs
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def extend(reqs, model_runner):
|
||||
batch = ScheduleBatch.init_new(
|
||||
reqs=reqs,
|
||||
@@ -235,6 +236,7 @@ def extend(reqs, model_runner):
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
@@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner):
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def correctness_test(
|
||||
server_args,
|
||||
port_args,
|
||||
@@ -287,7 +288,6 @@ def correctness_test(
|
||||
rank_print(tokenizer.decode(output_ids[i]), "\n")
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def latency_test_run_once(
|
||||
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user