[Feature] add multi-rank support for Lora (#4492)
Co-authored-by: rudy152 <czh1137892874@gmail.com>
This commit is contained in:
@@ -965,7 +965,7 @@ async def benchmark(
|
||||
request_rate: float,
|
||||
max_concurrency: Optional[int],
|
||||
disable_tqdm: bool,
|
||||
lora_name: str,
|
||||
lora_names: List[str],
|
||||
extra_request_body: Dict[str, Any],
|
||||
profile: bool,
|
||||
pd_seperated: bool = False,
|
||||
@@ -988,6 +988,11 @@ async def benchmark(
|
||||
# Warmup
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_prompt_len, test_output_len = input_requests[0]
|
||||
if lora_names != None and len(lora_names) != 0:
|
||||
lora_name = lora_names[0]
|
||||
else:
|
||||
lora_name = None
|
||||
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_prompt,
|
||||
@@ -1028,6 +1033,12 @@ async def benchmark(
|
||||
tasks: List[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate):
|
||||
prompt, prompt_len, output_len = request
|
||||
if lora_names != None and len(lora_names) != 0:
|
||||
idx = random.randint(0, len(lora_names) - 1)
|
||||
lora_name = lora_names[idx]
|
||||
else:
|
||||
lora_name = None
|
||||
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
@@ -1347,7 +1358,7 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
request_rate=args.request_rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
lora_names=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
pd_seperated=args.pd_seperated,
|
||||
@@ -1366,6 +1377,13 @@ def set_ulimit(target_soft_limit=65535):
|
||||
print(f"Fail to set RLIMIT_NOFILE: {e}")
|
||||
|
||||
|
||||
class LoRAPathAction(argparse.Action):
    """Argparse action that stores the supplied LoRA adapter names as a list.

    Each time the option is encountered, the destination attribute is
    replaced with a fresh list holding the names given for that occurrence.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``values`` on ``namespace`` under ``self.dest`` as a new list.

        Args:
            parser: The parser invoking this action (unused).
            namespace: Object on which the resulting list is set.
            values: The parsed argument value(s). A sequence is copied
                element by element; a single string is treated as one
                name; ``None`` produces an empty list.
            option_string: The option string that triggered the action
                (unused).
        """
        if values is None:
            names = []
        elif isinstance(values, str):
            # A bare string must stay one name; iterating it would split it
            # into individual characters.
            names = [values]
        else:
            # Copy so the stored list is independent of the caller's sequence.
            names = list(values)
        setattr(namespace, self.dest, names)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(description="Benchmark the online serving throughput.")
|
||||
parser.add_argument(
|
||||
@@ -1509,8 +1527,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--lora-name",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="The name of LoRA adapter",
|
||||
action=LoRAPathAction,
|
||||
help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-suffix",
|
||||
|
||||
Reference in New Issue
Block a user