From 8c2ffaaf0f59b22ced0d2076a8d74bccc54ad55f Mon Sep 17 00:00:00 2001
From: hzh0425
Date: Mon, 1 Sep 2025 05:51:18 +0800
Subject: [PATCH] fix(hicache-long-bench): adjust context workload generator to
 use full query set (#9847)

Co-authored-by: Zhiqiang Xie
---
 benchmark/hicache/bench_long_context.py | 11 ++++++++---
 benchmark/hicache/bench_multiturn.py    | 20 ++++++++++++++++++--
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/benchmark/hicache/bench_long_context.py b/benchmark/hicache/bench_long_context.py
index dc153b8a9..eed0ae5dc 100644
--- a/benchmark/hicache/bench_long_context.py
+++ b/benchmark/hicache/bench_long_context.py
@@ -31,9 +31,10 @@ class ContextWorkloadGenerator(WorkloadGenerator):
         self.completed_requests = 0
 
         self.dataset = json.load(open(args.dataset_path))
+        num_requests = min(args.num_clients, len(self.dataset["queries"]))
 
         init_requests = []
-        for i in range(min(args.num_clients, len(self.dataset["queries"]))):
+        for i in range(num_requests):
             context_id = self.dataset["queries"][i]["context"]
             init_requests.append(
                 (
@@ -52,13 +53,14 @@ class ContextWorkloadGenerator(WorkloadGenerator):
         self.ready_queue = ReadyQueue(init_requests=init_requests)
         self.response_queue = queue.Queue()
 
-        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
+        self.pbar = tqdm(total=num_requests)
         self.performance_metrics = {
             "ttft": [],
             "latency": [],
             "itl": [],
             "prompt_len": [],
             "cached_tokens": [],
+            "generated_len": [],
         }
 
         self.max_parallel = args.max_parallel
@@ -75,6 +77,9 @@ class ContextWorkloadGenerator(WorkloadGenerator):
                 self.performance_metrics["ttft"].append(response.ttft)
                 self.performance_metrics["itl"].extend(response.itl)
                 self.performance_metrics["latency"].append(response.latency)
+                self.performance_metrics["prompt_len"].append(response.prompt_len)
+                self.performance_metrics["cached_tokens"].append(response.cached_tokens)
+                self.performance_metrics["generated_len"].append(response.generated_len)
                 self.completed_requests += 1
 
             except queue.Empty:
@@ -85,7 +90,7 @@ class ContextWorkloadGenerator(WorkloadGenerator):
 if __name__ == "__main__":
     args = parse_args()
     args.num_rounds = 1
-    args.max_parallel = 128
+    args.max_parallel = 24
 
     flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
     for request_rate in [24, 16, 12, 8, 4, 2, 1]:
diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py
index 35e638d33..79829766c 100644
--- a/benchmark/hicache/bench_multiturn.py
+++ b/benchmark/hicache/bench_multiturn.py
@@ -191,6 +191,7 @@ async def async_request_sglang_generate(
                     output.latency = latency
                     output.prompt_len = prompt_tokens
                     output.cached_tokens = cached_tokens
+                    output.generated_len = len(output.itl) + 1
                 else:
                     output.error = response.reason or ""
                     output.success = False
@@ -321,6 +322,7 @@ class WorkloadGenerator:
             "latency": [],
             "prompt_len": [],
             "cached_tokens": [],
+            "generated_len": [],
         }
         self.num_rounds = args.num_rounds
         self.max_parallel = args.max_parallel
@@ -383,6 +385,7 @@ class WorkloadGenerator:
                 self.performance_metrics["latency"].append(response.latency)
                 self.performance_metrics["prompt_len"].append(response.prompt_len)
                 self.performance_metrics["cached_tokens"].append(response.cached_tokens)
+                self.performance_metrics["generated_len"].append(response.generated_len)
                 self.completed_requests += 1
 
                 if self.client_records[client_id]["round"] < self.num_rounds:
@@ -418,6 +421,7 @@ class WorkloadGenerator:
         response_thread.join()
         self.pbar.close()
 
+        duration = self.finished_time - self.start_time
         performance_data = {
             "summary": {
                 "total_requests": len(self.performance_metrics["ttft"]),
@@ -438,7 +442,13 @@ class WorkloadGenerator:
                 "median_latency": sorted(self.performance_metrics["latency"])[
                     len(self.performance_metrics["latency"]) // 2
                 ],
-                "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "input_token_throughput": sum(self.performance_metrics["prompt_len"])
+                / duration,
+                "output_token_throughput": sum(
+                    self.performance_metrics["generated_len"]
+                )
+                / duration,
+                "throughput": self.pbar.total / duration,
                 "cache_hit_rate": (
                     0
                     if sum(self.performance_metrics["prompt_len"]) == 0
@@ -461,7 +471,13 @@ class WorkloadGenerator:
         print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
         print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
         print(
-            f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
+            f" Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second"
+        )
+        print(
+            f" Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second"
+        )
+        print(
+            f" Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
         print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
         return performance_data