Add lora_path argument to bench_multiturn.py (#10092)

This commit is contained in:
Baizhou Zhang
2025-09-05 19:20:42 -07:00
committed by GitHub
parent 21b9a4b435
commit beac202bfd

View File

@@ -130,6 +130,12 @@ def parse_args():
help="Tag of a certain run in the log file", help="Tag of a certain run in the log file",
) )
parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--lora-path",
type=str,
default="",
help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.",
)
return parser.parse_args() return parser.parse_args()
@@ -205,7 +211,7 @@ async def async_request_sglang_generate(
return output return output
def gen_payload(prompt, output_len): def gen_payload(prompt, output_len, lora_path=""):
payload = { payload = {
"text": prompt, "text": prompt,
"sampling_params": { "sampling_params": {
@@ -215,7 +221,7 @@ def gen_payload(prompt, output_len):
}, },
"stream": True, "stream": True,
"stream_options": {"include_usage": True}, "stream_options": {"include_usage": True},
"lora_path": "", "lora_path": lora_path,
"return_logprob": False, "return_logprob": False,
"logprob_start_len": -1, "logprob_start_len": -1,
} }
@@ -303,7 +309,12 @@ class WorkloadGenerator:
) )
init_requests = [ init_requests = [
(i, gen_payload(self.candidate_inputs[i], args.output_length)) (
i,
gen_payload(
self.candidate_inputs[i], args.output_length, args.lora_path
),
)
for i in range(args.num_clients) for i in range(args.num_clients)
] ]
self.client_records = { self.client_records = {
@@ -399,6 +410,7 @@ class WorkloadGenerator:
gen_payload( gen_payload(
self.client_records[client_id]["history"], self.client_records[client_id]["history"],
self.output_length, self.output_length,
args.lora_path,
), ),
) )
) )