[Benchmark] add disable-auto-run param for hicache/bench_multiturn (#7822)
Co-authored-by: zhongwei.ren <zhongwei.ren@bytedance.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -9,6 +9,7 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import requests
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
@@ -97,6 +98,30 @@ def parse_args():
|
||||
default="performance_metrics.jsonl",
|
||||
help="File to log performance metrics",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-auto-run",
|
||||
action="store_true",
|
||||
help="If set, disable automatically testing with a range of request rates.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-random-sample",
|
||||
action="store_true",
|
||||
help="If set, disable random sampling of requests from the ShareGPT dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sub-question-input-length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Length of the sub question input for each request, if set 0 use request_length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ready-queue-policy",
|
||||
type=str,
|
||||
default="random",
|
||||
help="Policy for popping requests from the ready queue (random or fifo)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -234,13 +259,29 @@ class WorkloadGenerator:
|
||||
self.candidate_inputs = sample_random_requests(
|
||||
input_len=args.request_length,
|
||||
output_len=args.output_length,
|
||||
num_prompts=args.num_clients * args.num_rounds,
|
||||
num_prompts=args.num_clients,
|
||||
range_ratio=1.0,
|
||||
tokenizer=self.tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
random_sample=not args.disable_random_sample,
|
||||
)
|
||||
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
|
||||
|
||||
if args.sub_question_input_length != 0:
|
||||
sub_question_input_length = args.sub_question_input_length
|
||||
else:
|
||||
sub_question_input_length = args.request_length
|
||||
|
||||
self.sub_question_inputs = sample_random_requests(
|
||||
input_len=sub_question_input_length,
|
||||
output_len=args.output_length,
|
||||
num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
|
||||
range_ratio=1.0,
|
||||
tokenizer=self.tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
random_sample=not args.disable_random_sample,
|
||||
)
|
||||
|
||||
init_requests = [
|
||||
(i, gen_payload(self.candidate_inputs[i], args.output_length))
|
||||
for i in range(args.num_clients)
|
||||
@@ -249,7 +290,9 @@ class WorkloadGenerator:
|
||||
i: {"round": 0, "history": init_requests[i][1]["text"]}
|
||||
for i in range(args.num_clients)
|
||||
}
|
||||
self.ready_queue = ReadyQueue(init_requests=init_requests)
|
||||
self.ready_queue = ReadyQueue(
|
||||
init_requests=init_requests, policy=args.ready_queue_policy
|
||||
)
|
||||
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
|
||||
|
||||
self.response_queue = queue.Queue()
|
||||
@@ -314,9 +357,10 @@ class WorkloadGenerator:
|
||||
self.completed_requests += 1
|
||||
|
||||
if self.client_records[client_id]["round"] < args.num_rounds:
|
||||
# append new request to client's history
|
||||
self.client_records[client_id][
|
||||
"history"
|
||||
] += self.candidate_inputs.pop()
|
||||
] += self.sub_question_inputs.pop()
|
||||
self.ready_queue.append(
|
||||
(
|
||||
client_id,
|
||||
@@ -329,6 +373,9 @@ class WorkloadGenerator:
|
||||
except queue.Empty:
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
except ValueError as e:
|
||||
print(f"Error processing response for client {client_id}: {e}")
|
||||
continue
|
||||
|
||||
def run(self):
|
||||
request_thread = threading.Thread(target=self.request_sender, daemon=True)
|
||||
@@ -388,8 +435,18 @@ if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
|
||||
|
||||
for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
|
||||
args.request_rate = request_rate
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
if args.disable_auto_run:
|
||||
print("Running with specified request rate...")
|
||||
request_rates = [args.request_rate]
|
||||
else:
|
||||
print("Auto-running with different request rates...")
|
||||
request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
|
||||
|
||||
for rate in request_rates:
|
||||
args.request_rate = rate
|
||||
requests.post(flush_cache_url)
|
||||
time.sleep(1)
|
||||
WorkloadGenerator(args).run()
|
||||
|
||||
Reference in New Issue
Block a user