Misc fixes for eagle (flush_cache, CPU overhead) (#3014)

Author: Lianmin Zheng
Date:   2025-01-20 20:25:13 -08:00
Parent: d2571dd5c7
Commit: 287d07a669

11 changed files with 133 additions and 96 deletions


@@ -453,6 +453,7 @@ def get_dataset(args, tokenizer):
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
             context_len=args.sharegpt_context_len,
+            apply_chat_template=args.apply_chat_template,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
@@ -517,6 +518,7 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float
     std_e2e_latency_ms: float
     p99_e2e_latency_ms: float
+    concurrency: float


 SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
@@ -562,6 +564,7 @@ def sample_sharegpt_requests(
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
     context_len: Optional[int] = None,
+    apply_chat_template=False,
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -592,6 +595,15 @@ def sample_sharegpt_requests(
         # Tokenize the prompts and completions.
         prompt = dataset[i][0]
+        if apply_chat_template:
+            prompt = tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+            prompt = prompt.replace(tokenizer.bos_token, "")
+
         prompt_token_ids = tokenizer.encode(prompt)
         completion = dataset[i][1]
         completion_token_ids = tokenizer.encode(completion)
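For reference, the new `--apply-chat-template` path wraps each raw ShareGPT prompt in a single user turn, renders it as text, and strips the BOS token so the server-side tokenizer does not prepend a second one. A minimal standalone sketch of the same transformation (the checkpoint name is only an example, not part of this change):

```python
from transformers import AutoTokenizer

# Example chat model; any tokenizer that ships a chat template works.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

prompt = "Explain speculative decoding in one paragraph."
templated = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    tokenize=False,
)
# Drop the BOS token rendered by the template, mirroring the diff above;
# tokenizer.encode() would otherwise add a second BOS during encoding.
if tokenizer.bos_token:
    templated = templated.replace(tokenizer.bos_token, "")
print(templated)
```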
@@ -600,7 +612,7 @@ def sample_sharegpt_requests(
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
         )
-        if prompt_len < 1 or output_len < 1:
+        if prompt_len < 2 or output_len < 2:
             # Prune too short sequences.
             continue
@@ -880,6 +892,7 @@ def calculate_metrics(
         median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
         std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
         p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
+        concurrency=np.sum(e2e_latencies) / dur_s,
     )

     return metrics, output_lens
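The new `concurrency` field is the time-averaged number of in-flight requests: by Little's law, the sum of all per-request end-to-end latencies divided by the wall-clock duration equals the mean number of requests being served at once. A toy check of the formula used above:

```python
import numpy as np

# Four requests, each taking 2 s end to end, completed within a 4 s
# benchmark window: on average two requests were in flight at any moment.
e2e_latencies = np.array([2.0, 2.0, 2.0, 2.0])
dur_s = 4.0
print(np.sum(e2e_latencies) / dur_s)  # 2.0
```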
@@ -1031,6 +1044,7 @@ async def benchmark(
             "Total token throughput (tok/s):", metrics.total_throughput
         )
     )
+    print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
     print(
         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1062,13 +1076,24 @@ async def benchmark(
         and metrics.output_throughput is not None
     ):
         result = {
+            # Arguments
             "backend": args.backend,
             "dataset_name": args.dataset_name,
             "request_rate": request_rate,
             "max_concurrency": max_concurrency,
+            "sharegpt_output_len": args.sharegpt_output_len,
+            "random_input_len": args.random_input_len,
+            "random_output_len": args.random_output_len,
+            "random_range_ratio": args.random_range_ratio,
+            # Results
+            "duration": benchmark_duration,
+            "completed": metrics.completed,
             "total_input_tokens": metrics.total_input,
             "total_output_tokens": metrics.total_output,
             "total_output_tokens_retokenized": metrics.total_output_retokenized,
             "request_throughput": metrics.request_throughput,
+            "input_throughput": metrics.input_throughput,
+            "output_throughput": metrics.output_throughput,
             "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
             "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
             "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
@@ -1085,14 +1110,7 @@ async def benchmark(
             "median_itl_ms": metrics.median_itl_ms,
             "std_itl_ms": metrics.std_itl_ms,
             "p99_itl_ms": metrics.p99_itl_ms,
-            "input_throughput": metrics.input_throughput,
-            "output_throughput": metrics.output_throughput,
-            "sharegpt_output_len": args.sharegpt_output_len,
-            "random_input_len": args.random_input_len,
-            "random_output_len": args.random_output_len,
-            "random_range_ratio": args.random_range_ratio,
-            "duration": benchmark_duration,
-            "completed": metrics.completed,
+            "concurrency": metrics.concurrency,
         }
     else:
         print(f"Error running benchmark for request rate: {request_rate}")
@@ -1112,36 +1130,16 @@ async def benchmark(
         with open(output_file_name, "a") as file:
             file.write(json.dumps(result) + "\n")

-    result = {
-        "duration": benchmark_duration,
-        "completed": metrics.completed,
-        "total_input_tokens": metrics.total_input,
-        "total_output_tokens": metrics.total_output,
-        "total_output_tokens_retokenized": metrics.total_output_retokenized,
-        "request_throughput": metrics.request_throughput,
-        "input_throughput": metrics.input_throughput,
-        "output_throughput": metrics.output_throughput,
-        "mean_ttft_ms": metrics.mean_ttft_ms,
-        "median_ttft_ms": metrics.median_ttft_ms,
-        "std_ttft_ms": metrics.std_ttft_ms,
-        "p99_ttft_ms": metrics.p99_ttft_ms,
-        "mean_tpot_ms": metrics.mean_tpot_ms,
-        "median_tpot_ms": metrics.median_tpot_ms,
-        "std_tpot_ms": metrics.std_tpot_ms,
-        "p99_tpot_ms": metrics.p99_tpot_ms,
-        "mean_itl_ms": metrics.mean_itl_ms,
-        "median_itl_ms": metrics.median_itl_ms,
-        "std_itl_ms": metrics.std_itl_ms,
-        "p99_itl_ms": metrics.p99_itl_ms,
-        "input_lens": [output.prompt_len for output in outputs],
-        "output_lens": output_lens,
-        "ttfts": [output.ttft for output in outputs],
-        "itls": [output.itl for output in outputs],
-        "generated_texts": [output.generated_text for output in outputs],
-        "errors": [output.error for output in outputs],
-        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
-        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
-    }
+    result.update(
+        {
+            "input_lens": [output.prompt_len for output in outputs],
+            "output_lens": output_lens,
+            "ttfts": [output.ttft for output in outputs],
+            "itls": [output.itl for output in outputs],
+            "generated_texts": [output.generated_text for output in outputs],
+            "errors": [output.error for output in outputs],
+        }
+    )
     return result
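Because each run appends one JSON object per line, the saved results stay easy to post-process with plain `json`. A small sketch of reading them back (the file name here is a placeholder for whatever output file the benchmark was pointed at):

```python
import json

# "results.jsonl" is a hypothetical name; substitute the actual output file.
with open("results.jsonl") as f:
    for line in f:
        record = json.loads(line)
        print(record["request_throughput"], record["concurrency"])
```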
@@ -1422,7 +1420,6 @@ if __name__ == "__main__":
         "actual request rate may be lower than specified with --request-rate, "
         "if the server is not processing requests fast enough to keep up.",
     )
-    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
         "--multi",
         action="store_true",
@@ -1445,16 +1442,17 @@ if __name__ == "__main__":
         action="store_true",
         help="Disable streaming mode.",
     )
-    parser.add_argument(
-        "--disable-ignore-eos",
-        action="store_true",
-        help="Disable ignoring EOS.",
-    )
     parser.add_argument(
         "--return-logprob",
         action="store_true",
         help="Return logprob.",
     )
+    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+    parser.add_argument(
+        "--disable-ignore-eos",
+        action="store_true",
+        help="Disable ignoring EOS.",
+    )
     parser.add_argument(
         "--extra-request-body",
         metavar='{"key1": "value1", "key2": "value2"}',
@@ -1462,6 +1460,11 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
+    parser.add_argument(
+        "--apply-chat-template",
+        action="store_true",
+        help="Apply chat template",
+    )
     parser.add_argument(
         "--profile",
         action="store_true",
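For completeness, a hedged sketch of exercising the new flag end to end; the module path and flag combination are assumptions based on how this benchmark script is typically launched, not something this diff shows:

```python
import subprocess

# Assumed entry point for the serving benchmark; adjust to your setup.
subprocess.run(
    [
        "python3", "-m", "sglang.bench_serving",
        "--backend", "sglang",
        "--dataset-name", "sharegpt",
        "--num-prompts", "200",
        "--apply-chat-template",
        "--seed", "1",
    ],
    check=True,
)
```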