Add lora_path argument to bench_multiturn.py (#10092)
This commit is contained in:
@@ -130,6 +130,12 @@ def parse_args():
|
|||||||
help="Tag of a certain run in the log file",
|
help="Tag of a certain run in the log file",
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-path",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -205,7 +211,7 @@ async def async_request_sglang_generate(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def gen_payload(prompt, output_len):
|
def gen_payload(prompt, output_len, lora_path=""):
|
||||||
payload = {
|
payload = {
|
||||||
"text": prompt,
|
"text": prompt,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
@@ -215,7 +221,7 @@ def gen_payload(prompt, output_len):
|
|||||||
},
|
},
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"stream_options": {"include_usage": True},
|
"stream_options": {"include_usage": True},
|
||||||
"lora_path": "",
|
"lora_path": lora_path,
|
||||||
"return_logprob": False,
|
"return_logprob": False,
|
||||||
"logprob_start_len": -1,
|
"logprob_start_len": -1,
|
||||||
}
|
}
|
||||||
@@ -303,7 +309,12 @@ class WorkloadGenerator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
init_requests = [
|
init_requests = [
|
||||||
(i, gen_payload(self.candidate_inputs[i], args.output_length))
|
(
|
||||||
|
i,
|
||||||
|
gen_payload(
|
||||||
|
self.candidate_inputs[i], args.output_length, args.lora_path
|
||||||
|
),
|
||||||
|
)
|
||||||
for i in range(args.num_clients)
|
for i in range(args.num_clients)
|
||||||
]
|
]
|
||||||
self.client_records = {
|
self.client_records = {
|
||||||
@@ -399,6 +410,7 @@ class WorkloadGenerator:
|
|||||||
gen_payload(
|
gen_payload(
|
||||||
self.client_records[client_id]["history"],
|
self.client_records[client_id]["history"],
|
||||||
self.output_length,
|
self.output_length,
|
||||||
|
args.lora_path,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user