diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py index 79829766c..a3e8b0d74 100644 --- a/benchmark/hicache/bench_multiturn.py +++ b/benchmark/hicache/bench_multiturn.py @@ -130,6 +130,12 @@ def parse_args(): help="Tag of a certain run in the log file", ) parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--lora-path", + type=str, + default="", + help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.", + ) return parser.parse_args() @@ -205,7 +211,7 @@ async def async_request_sglang_generate( return output -def gen_payload(prompt, output_len): +def gen_payload(prompt, output_len, lora_path=""): payload = { "text": prompt, "sampling_params": { @@ -215,7 +221,7 @@ def gen_payload(prompt, output_len): }, "stream": True, "stream_options": {"include_usage": True}, - "lora_path": "", + "lora_path": lora_path, "return_logprob": False, "logprob_start_len": -1, } @@ -303,7 +309,12 @@ class WorkloadGenerator: ) init_requests = [ - (i, gen_payload(self.candidate_inputs[i], args.output_length)) + ( + i, + gen_payload( + self.candidate_inputs[i], args.output_length, args.lora_path + ), + ) for i in range(args.num_clients) ] self.client_records = { @@ -399,6 +410,7 @@ class WorkloadGenerator: gen_payload( self.client_records[client_id]["history"], self.output_length, + args.lora_path, ), ) )