Add lora_path argument to bench_multiturn.py (#10092)

This commit is contained in:
Baizhou Zhang
2025-09-05 19:20:42 -07:00
committed by GitHub
parent 21b9a4b435
commit beac202bfd

View File

@@ -130,6 +130,12 @@ def parse_args():
help="Tag of a certain run in the log file",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--lora-path",
type=str,
default="",
help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.",
)
return parser.parse_args()
@@ -205,7 +211,7 @@ async def async_request_sglang_generate(
return output
def gen_payload(prompt, output_len):
def gen_payload(prompt, output_len, lora_path=""):
payload = {
"text": prompt,
"sampling_params": {
@@ -215,7 +221,7 @@ def gen_payload(prompt, output_len):
},
"stream": True,
"stream_options": {"include_usage": True},
"lora_path": "",
"lora_path": lora_path,
"return_logprob": False,
"logprob_start_len": -1,
}
@@ -303,7 +309,12 @@ class WorkloadGenerator:
)
init_requests = [
(i, gen_payload(self.candidate_inputs[i], args.output_length))
(
i,
gen_payload(
self.candidate_inputs[i], args.output_length, args.lora_path
),
)
for i in range(args.num_clients)
]
self.client_records = {
@@ -399,6 +410,7 @@ class WorkloadGenerator:
gen_payload(
self.client_records[client_id]["history"],
self.output_length,
args.lora_path,
),
)
)