[Benchmark] add disable-auto-run param for hicache/bench_multiturn (#7822)
Co-authored-by: zhongwei.ren <zhongwei.ren@bytedance.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -9,6 +9,7 @@ from datetime import datetime
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
|
|
||||||
@@ -97,6 +98,30 @@ def parse_args():
|
|||||||
default="performance_metrics.jsonl",
|
default="performance_metrics.jsonl",
|
||||||
help="File to log performance metrics",
|
help="File to log performance metrics",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-auto-run",
|
||||||
|
action="store_true",
|
||||||
|
help="If set, disable automatically testing with a range of request rates.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-random-sample",
|
||||||
|
action="store_true",
|
||||||
|
help="If set, disable random sampling of requests from the ShareGPT dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sub-question-input-length",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Length of the sub question input for each request, if set 0 use request_length",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ready-queue-policy",
|
||||||
|
type=str,
|
||||||
|
default="random",
|
||||||
|
help="Policy for popping requests from the ready queue (random or fifo)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -234,13 +259,29 @@ class WorkloadGenerator:
|
|||||||
self.candidate_inputs = sample_random_requests(
|
self.candidate_inputs = sample_random_requests(
|
||||||
input_len=args.request_length,
|
input_len=args.request_length,
|
||||||
output_len=args.output_length,
|
output_len=args.output_length,
|
||||||
num_prompts=args.num_clients * args.num_rounds,
|
num_prompts=args.num_clients,
|
||||||
range_ratio=1.0,
|
range_ratio=1.0,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
|
random_sample=not args.disable_random_sample,
|
||||||
)
|
)
|
||||||
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
|
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
|
||||||
|
|
||||||
|
if args.sub_question_input_length != 0:
|
||||||
|
sub_question_input_length = args.sub_question_input_length
|
||||||
|
else:
|
||||||
|
sub_question_input_length = args.request_length
|
||||||
|
|
||||||
|
self.sub_question_inputs = sample_random_requests(
|
||||||
|
input_len=sub_question_input_length,
|
||||||
|
output_len=args.output_length,
|
||||||
|
num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
|
||||||
|
range_ratio=1.0,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
random_sample=not args.disable_random_sample,
|
||||||
|
)
|
||||||
|
|
||||||
init_requests = [
|
init_requests = [
|
||||||
(i, gen_payload(self.candidate_inputs[i], args.output_length))
|
(i, gen_payload(self.candidate_inputs[i], args.output_length))
|
||||||
for i in range(args.num_clients)
|
for i in range(args.num_clients)
|
||||||
@@ -249,7 +290,9 @@ class WorkloadGenerator:
|
|||||||
i: {"round": 0, "history": init_requests[i][1]["text"]}
|
i: {"round": 0, "history": init_requests[i][1]["text"]}
|
||||||
for i in range(args.num_clients)
|
for i in range(args.num_clients)
|
||||||
}
|
}
|
||||||
self.ready_queue = ReadyQueue(init_requests=init_requests)
|
self.ready_queue = ReadyQueue(
|
||||||
|
init_requests=init_requests, policy=args.ready_queue_policy
|
||||||
|
)
|
||||||
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
|
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
|
||||||
|
|
||||||
self.response_queue = queue.Queue()
|
self.response_queue = queue.Queue()
|
||||||
@@ -314,9 +357,10 @@ class WorkloadGenerator:
|
|||||||
self.completed_requests += 1
|
self.completed_requests += 1
|
||||||
|
|
||||||
if self.client_records[client_id]["round"] < args.num_rounds:
|
if self.client_records[client_id]["round"] < args.num_rounds:
|
||||||
|
# append new request to client's history
|
||||||
self.client_records[client_id][
|
self.client_records[client_id][
|
||||||
"history"
|
"history"
|
||||||
] += self.candidate_inputs.pop()
|
] += self.sub_question_inputs.pop()
|
||||||
self.ready_queue.append(
|
self.ready_queue.append(
|
||||||
(
|
(
|
||||||
client_id,
|
client_id,
|
||||||
@@ -329,6 +373,9 @@ class WorkloadGenerator:
|
|||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
if self.pbar.n == self.pbar.total:
|
if self.pbar.n == self.pbar.total:
|
||||||
break
|
break
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"Error processing response for client {client_id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
request_thread = threading.Thread(target=self.request_sender, daemon=True)
|
request_thread = threading.Thread(target=self.request_sender, daemon=True)
|
||||||
@@ -388,8 +435,18 @@ if __name__ == "__main__":
|
|||||||
args = parse_args()
|
args = parse_args()
|
||||||
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
|
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
|
||||||
|
|
||||||
for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
|
random.seed(args.seed)
|
||||||
args.request_rate = request_rate
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
|
if args.disable_auto_run:
|
||||||
|
print("Running with specified request rate...")
|
||||||
|
request_rates = [args.request_rate]
|
||||||
|
else:
|
||||||
|
print("Auto-running with different request rates...")
|
||||||
|
request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
|
||||||
|
|
||||||
|
for rate in request_rates:
|
||||||
|
args.request_rate = rate
|
||||||
requests.post(flush_cache_url)
|
requests.post(flush_cache_url)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
WorkloadGenerator(args).run()
|
WorkloadGenerator(args).run()
|
||||||
|
|||||||
Reference in New Issue
Block a user