[Fix] Fix logprob and normalized_logprob (#1428)

This commit is contained in:
Lianmin Zheng
2024-09-15 06:36:06 -07:00
committed by GitHub
parent 282681b8a1
commit 9ba1f09760
22 changed files with 314 additions and 215 deletions

View File

@@ -164,6 +164,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
req.prefix_indices = []
req.sampling_params = sampling_params
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
return input_ids, reqs
@@ -178,6 +179,7 @@ def prepare_extend_inputs_for_correctness_test(
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
i, : bench_args.cut_len
]
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
return reqs
@@ -194,6 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
req.prefix_indices = []
req.sampling_params = sampling_params
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
return reqs