Add lora_path argument to bench_multiturn.py (#10092)
This commit is contained in:
@@ -130,6 +130,12 @@ def parse_args():
|
||||
help="Tag of a certain run in the log file",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -205,7 +211,7 @@ async def async_request_sglang_generate(
|
||||
return output
|
||||
|
||||
|
||||
def gen_payload(prompt, output_len):
|
||||
def gen_payload(prompt, output_len, lora_path=""):
|
||||
payload = {
|
||||
"text": prompt,
|
||||
"sampling_params": {
|
||||
@@ -215,7 +221,7 @@ def gen_payload(prompt, output_len):
|
||||
},
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
"lora_path": "",
|
||||
"lora_path": lora_path,
|
||||
"return_logprob": False,
|
||||
"logprob_start_len": -1,
|
||||
}
|
||||
@@ -303,7 +309,12 @@ class WorkloadGenerator:
|
||||
)
|
||||
|
||||
init_requests = [
|
||||
(i, gen_payload(self.candidate_inputs[i], args.output_length))
|
||||
(
|
||||
i,
|
||||
gen_payload(
|
||||
self.candidate_inputs[i], args.output_length, args.lora_path
|
||||
),
|
||||
)
|
||||
for i in range(args.num_clients)
|
||||
]
|
||||
self.client_records = {
|
||||
@@ -399,6 +410,7 @@ class WorkloadGenerator:
|
||||
gen_payload(
|
||||
self.client_records[client_id]["history"],
|
||||
self.output_length,
|
||||
args.lora_path,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user