Misc fixes for eagle (flush_cache, CPU overhead) (#3014)
@@ -453,6 +453,7 @@ def get_dataset(args, tokenizer):
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
            context_len=args.sharegpt_context_len,
            apply_chat_template=args.apply_chat_template,
        )
    elif args.dataset_name == "random":
        input_requests = sample_random_requests(
@@ -517,6 +518,7 @@ class BenchmarkMetrics:
    median_e2e_latency_ms: float
    std_e2e_latency_ms: float
    p99_e2e_latency_ms: float
    concurrency: float


SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
@@ -562,6 +564,7 @@ def sample_sharegpt_requests(
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    context_len: Optional[int] = None,
    apply_chat_template=False,
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")
@@ -592,6 +595,15 @@ def sample_sharegpt_requests(

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]

        if apply_chat_template:
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )
            prompt = prompt.replace(tokenizer.bos_token, "")

        prompt_token_ids = tokenizer.encode(prompt)
        completion = dataset[i][1]
        completion_token_ids = tokenizer.encode(completion)
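The chat-template handling added above can be exercised on its own; the sketch below mirrors it outside the benchmark loop (the model name and prompt are illustrative assumptions, not part of this commit):

# Standalone sketch of the apply_chat_template path; the model name and
# prompt are illustrative assumptions, not part of this change.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
raw_prompt = "Summarize the plot of Hamlet in two sentences."

# Wrap the raw ShareGPT turn as a single user message and render it with the
# model's chat template, leaving the generation prompt open for the assistant.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": raw_prompt}],
    add_generation_prompt=True,
    tokenize=False,
)
# Some templates re-insert the BOS token as literal text; strip it so the
# later tokenizer.encode() call does not count it twice.
if tokenizer.bos_token:
    prompt = prompt.replace(tokenizer.bos_token, "")

prompt_len = len(tokenizer.encode(prompt))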
@@ -600,7 +612,7 @@ def sample_sharegpt_requests(
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )

        if prompt_len < 1 or output_len < 1:
        if prompt_len < 2 or output_len < 2:
            # Prune too short sequences.
            continue
@@ -880,6 +892,7 @@ def calculate_metrics(
        median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
        std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
        p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
        concurrency=np.sum(e2e_latencies) / dur_s,
    )

    return metrics, output_lens
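The new concurrency metric is the sum of per-request end-to-end latencies divided by the wall-clock duration, i.e. the time-averaged number of requests in flight (Little's law). A small numeric sketch with made-up values:

# Numeric sketch of the concurrency metric; the values are invented, only the
# formula mirrors the diff above.
import numpy as np

e2e_latencies = [2.0, 2.0, 4.0]  # per-request end-to-end latencies, seconds
dur_s = 4.0                      # wall-clock benchmark duration, seconds

# Each request contributes its whole lifetime to the integral of "requests in
# flight over time"; dividing by the duration gives the average number of
# requests running at once (Little's law: L = lambda * W).
concurrency = np.sum(e2e_latencies) / dur_s  # -> 2.0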
@@ -1031,6 +1044,7 @@ async def benchmark(
            "Total token throughput (tok/s):", metrics.total_throughput
        )
    )
    print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1062,13 +1076,24 @@ async def benchmark(
        and metrics.output_throughput is not None
    ):
        result = {
            # Arguments
            "backend": args.backend,
            "dataset_name": args.dataset_name,
            "request_rate": request_rate,
            "max_concurrency": max_concurrency,
            "sharegpt_output_len": args.sharegpt_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            # Results
            "duration": benchmark_duration,
            "completed": metrics.completed,
            "total_input_tokens": metrics.total_input,
            "total_output_tokens": metrics.total_output,
            "total_output_tokens_retokenized": metrics.total_output_retokenized,
            "request_throughput": metrics.request_throughput,
            "input_throughput": metrics.input_throughput,
            "output_throughput": metrics.output_throughput,
            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
            "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
@@ -1085,14 +1110,7 @@ async def benchmark(
            "median_itl_ms": metrics.median_itl_ms,
            "std_itl_ms": metrics.std_itl_ms,
            "p99_itl_ms": metrics.p99_itl_ms,
            "input_throughput": metrics.input_throughput,
            "output_throughput": metrics.output_throughput,
            "sharegpt_output_len": args.sharegpt_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            "duration": benchmark_duration,
            "completed": metrics.completed,
            "concurrency": metrics.concurrency,
        }
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
@@ -1112,36 +1130,16 @@ async def benchmark(
    with open(output_file_name, "a") as file:
        file.write(json.dumps(result) + "\n")

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "total_output_tokens_retokenized": metrics.total_output_retokenized,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
    }
    result.update(
        {
            "input_lens": [output.prompt_len for output in outputs],
            "output_lens": output_lens,
            "ttfts": [output.ttft for output in outputs],
            "itls": [output.itl for output in outputs],
            "generated_texts": [output.generated_text for output in outputs],
            "errors": [output.error for output in outputs],
        }
    )
    return result
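For orientation, the summary record built by the hunks above and appended to the output file has roughly the shape sketched below; the keys come from this diff, the values are invented:

# Rough shape of the JSON line written per run (values are illustrative only).
result = {
    # Arguments
    "backend": "sglang",
    "dataset_name": "sharegpt",
    "request_rate": 16.0,
    "max_concurrency": None,
    # Results
    "duration": 120.3,
    "completed": 1000,
    "request_throughput": 8.3,
    "output_throughput": 1650.2,
    "mean_e2e_latency_ms": 2350.1,
    "concurrency": 19.5,
}
# Per-request detail lists ("input_lens", "output_lens", "ttfts", "itls",
# "generated_texts", "errors") are merged in afterwards via result.update(...)
# for the dict returned to the caller, not for the file record.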
@@ -1422,7 +1420,6 @@ if __name__ == "__main__":
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--multi",
        action="store_true",
@@ -1445,16 +1442,17 @@ if __name__ == "__main__":
        action="store_true",
        help="Disable streaming mode.",
    )
    parser.add_argument(
        "--disable-ignore-eos",
        action="store_true",
        help="Disable ignoring EOS.",
    )
    parser.add_argument(
        "--return-logprob",
        action="store_true",
        help="Return logprob.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--disable-ignore-eos",
        action="store_true",
        help="Disable ignoring EOS.",
    )
    parser.add_argument(
        "--extra-request-body",
        metavar='{"key1": "value1", "key2": "value2"}',
@@ -1462,6 +1460,11 @@ if __name__ == "__main__":
        help="Append given JSON object to the request payload. You can use this to specify"
        "additional generate params like sampling params.",
    )
    parser.add_argument(
        "--apply-chat-template",
        action="store_true",
        help="Apply chat template",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
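With the new flag in place, a ShareGPT run that exercises the chat-template path could be launched roughly as follows; only --apply-chat-template comes from this diff, while the module path and the other flag values are assumptions for illustration:

python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt \
    --num-prompts 1000 --apply-chat-template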