diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py
index 5b8d706a3..5e954ecd6 100644
--- a/benchmark/hicache/bench_multiturn.py
+++ b/benchmark/hicache/bench_multiturn.py
@@ -9,6 +9,7 @@ from datetime import datetime
 from typing import Optional
 
 import aiohttp
+import numpy as np
 import requests
 from tqdm.asyncio import tqdm
 
@@ -97,6 +98,30 @@ def parse_args():
         default="performance_metrics.jsonl",
         help="File to log performance metrics",
     )
+    parser.add_argument(
+        "--disable-auto-run",
+        action="store_true",
+        help="If set, disable automatic testing across a range of request rates.",
+    )
+
+    parser.add_argument(
+        "--disable-random-sample",
+        action="store_true",
+        help="If set, disable random sampling of requests from the ShareGPT dataset.",
+    )
+    parser.add_argument(
+        "--sub-question-input-length",
+        type=int,
+        default=0,
+        help="Length of the sub-question input for each request; if 0, use request_length.",
+    )
+    parser.add_argument(
+        "--ready-queue-policy",
+        type=str,
+        default="random",
+        help="Policy for popping requests from the ready queue (random or fifo).",
+    )
+    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
 
     return parser.parse_args()
 
@@ -234,13 +259,29 @@ class WorkloadGenerator:
         self.candidate_inputs = sample_random_requests(
             input_len=args.request_length,
             output_len=args.output_length,
-            num_prompts=args.num_clients * args.num_rounds,
+            num_prompts=args.num_clients,
             range_ratio=1.0,
             tokenizer=self.tokenizer,
             dataset_path=args.dataset_path,
+            random_sample=not args.disable_random_sample,
         )
         self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
 
+        if args.sub_question_input_length != 0:
+            sub_question_input_length = args.sub_question_input_length
+        else:
+            sub_question_input_length = args.request_length
+
+        self.sub_question_inputs = sample_random_requests(
+            input_len=sub_question_input_length,
+            output_len=args.output_length,
+            num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
+            range_ratio=1.0,
+            tokenizer=self.tokenizer,
+            dataset_path=args.dataset_path,
+            random_sample=not args.disable_random_sample,
+        )
+
         init_requests = [
             (i, gen_payload(self.candidate_inputs[i], args.output_length))
             for i in range(args.num_clients)
@@ -249,7 +290,9 @@ class WorkloadGenerator:
             i: {"round": 0, "history": init_requests[i][1]["text"]}
             for i in range(args.num_clients)
         }
-        self.ready_queue = ReadyQueue(init_requests=init_requests)
+        self.ready_queue = ReadyQueue(
+            init_requests=init_requests, policy=args.ready_queue_policy
+        )
 
         self.candidate_inputs = self.candidate_inputs[args.num_clients :]
         self.response_queue = queue.Queue()
@@ -314,9 +357,10 @@ class WorkloadGenerator:
                     self.completed_requests += 1
 
                     if self.client_records[client_id]["round"] < args.num_rounds:
+                        # append the next sub-question to the client's history
                         self.client_records[client_id][
                             "history"
-                        ] += self.candidate_inputs.pop()
+                        ] += self.sub_question_inputs.pop()
                         self.ready_queue.append(
                             (
                                 client_id,
@@ -329,6 +373,9 @@ class WorkloadGenerator:
             except queue.Empty:
                 if self.pbar.n == self.pbar.total:
                     break
+            except ValueError as e:
+                print(f"Error processing response for client {client_id}: {e}")
+                continue
 
     def run(self):
         request_thread = threading.Thread(target=self.request_sender, daemon=True)
@@ -388,8 +435,18 @@ if __name__ == "__main__":
     args = parse_args()
     flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
 
-    for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
-        args.request_rate = request_rate
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    if args.disable_auto_run:
+        print("Running with the specified request rate...")
+        request_rates = [args.request_rate]
+    else:
+        print("Auto-running with different request rates...")
+        request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+
+    for rate in request_rates:
+        args.request_rate = rate
         requests.post(flush_cache_url)
         time.sleep(1)
         WorkloadGenerator(args).run()
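
A note on the new policy argument: the diff passes policy=args.ready_queue_policy into ReadyQueue, but the class itself lies outside the hunks shown here. The sketch below is a minimal reconstruction of a queue honoring the two documented policies ("random" and "fifo"); only the constructor keywords and the append() call are taken from the diff, while the pop semantics and locking are assumptions for illustration.

# Hypothetical sketch of a ReadyQueue supporting the documented policies.
# Only __init__(init_requests=..., policy=...) and append() appear in the
# diff; pop() and the locking are assumed for illustration.
import random
import threading


class ReadyQueue:
    def __init__(self, init_requests=None, policy="random"):
        if policy not in ("random", "fifo"):
            raise ValueError(f"Unknown ready-queue policy: {policy}")
        self.queue = list(init_requests or [])
        self.policy = policy
        self.lock = threading.Lock()  # sender and handler threads share the queue

    def append(self, item):
        with self.lock:
            self.queue.append(item)

    def pop(self):
        with self.lock:
            if not self.queue:
                return None
            if self.policy == "fifo":
                return self.queue.pop(0)  # strict arrival order
            # "random": issue a uniformly random pending request next
            return self.queue.pop(random.randrange(len(self.queue)))

A FIFO pop preserves arrival order, while the random policy shuffles which client's next turn goes out first; together with the new --seed flag (seeding both random and np.random in the last hunk), this keeps the pop order and the sampled prompts reproducible across runs at each request rate.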