[Benchmark] add disable-auto-run param for hicache/bench_multiturn (#7822)

Co-authored-by: zhongwei.ren <zhongwei.ren@bytedance.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
zhongwei
2025-07-23 05:02:40 +08:00
committed by GitHub
parent 0f8b538614
commit ff45ab7a5f

View File

@@ -9,6 +9,7 @@ from datetime import datetime
 from typing import Optional

 import aiohttp
+import numpy as np
 import requests
 from tqdm.asyncio import tqdm
@@ -97,6 +98,30 @@ def parse_args():
         default="performance_metrics.jsonl",
         help="File to log performance metrics",
     )
+    parser.add_argument(
+        "--disable-auto-run",
+        action="store_true",
+        help="If set, disable automatically testing with a range of request rates.",
+    )
+    parser.add_argument(
+        "--disable-random-sample",
+        action="store_true",
+        help="If set, disable random sampling of requests from the ShareGPT dataset.",
+    )
+    parser.add_argument(
+        "--sub-question-input-length",
+        type=int,
+        default=0,
+        help="Length of the sub question input for each request, if set 0 use request_length",
+    )
+    parser.add_argument(
+        "--ready-queue-policy",
+        type=str,
+        default="random",
+        help="Policy for popping requests from the ready queue (random or fifo)",
+    )
+    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     return parser.parse_args()
@@ -234,13 +259,29 @@ class WorkloadGenerator:
         self.candidate_inputs = sample_random_requests(
             input_len=args.request_length,
             output_len=args.output_length,
-            num_prompts=args.num_clients * args.num_rounds,
+            num_prompts=args.num_clients,
             range_ratio=1.0,
             tokenizer=self.tokenizer,
             dataset_path=args.dataset_path,
+            random_sample=not args.disable_random_sample,
         )
         self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
+        if args.sub_question_input_length != 0:
+            sub_question_input_length = args.sub_question_input_length
+        else:
+            sub_question_input_length = args.request_length
+        self.sub_question_inputs = sample_random_requests(
+            input_len=sub_question_input_length,
+            output_len=args.output_length,
+            num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
+            range_ratio=1.0,
+            tokenizer=self.tokenizer,
+            dataset_path=args.dataset_path,
+            random_sample=not args.disable_random_sample,
+        )
         init_requests = [
             (i, gen_payload(self.candidate_inputs[i], args.output_length))
             for i in range(args.num_clients)
@@ -249,7 +290,9 @@ class WorkloadGenerator:
             i: {"round": 0, "history": init_requests[i][1]["text"]}
             for i in range(args.num_clients)
         }
-        self.ready_queue = ReadyQueue(init_requests=init_requests)
+        self.ready_queue = ReadyQueue(
+            init_requests=init_requests, policy=args.ready_queue_policy
+        )
         self.candidate_inputs = self.candidate_inputs[args.num_clients :]
         self.response_queue = queue.Queue()
@@ -314,9 +357,10 @@ class WorkloadGenerator:
                 self.completed_requests += 1
                 if self.client_records[client_id]["round"] < args.num_rounds:
+                    # append new request to client's history
                     self.client_records[client_id][
                         "history"
-                    ] += self.candidate_inputs.pop()
+                    ] += self.sub_question_inputs.pop()
                     self.ready_queue.append(
                         (
                             client_id,
@@ -329,6 +373,9 @@ class WorkloadGenerator:
             except queue.Empty:
                 if self.pbar.n == self.pbar.total:
                     break
+            except ValueError as e:
+                print(f"Error processing response for client {client_id}: {e}")
+                continue

     def run(self):
         request_thread = threading.Thread(target=self.request_sender, daemon=True)
@@ -388,8 +435,18 @@ if __name__ == "__main__":
     args = parse_args()
     flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
-    for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
-        args.request_rate = request_rate
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    if args.disable_auto_run:
+        print("Running with specified request rate...")
+        request_rates = [args.request_rate]
+    else:
+        print("Auto-running with different request rates...")
+        request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+    for rate in request_rates:
+        args.request_rate = rate
         requests.post(flush_cache_url)
         time.sleep(1)
         WorkloadGenerator(args).run()