[Feature] Add multi-rank support for LoRA (#4492)

Co-authored-by: rudy152 <czh1137892874@gmail.com>
This commit is contained in:
chaobo jia
2025-03-29 00:38:44 +08:00
committed by GitHub
parent 6dea5c96bf
commit ef9a378a20
16 changed files with 292 additions and 97 deletions

View File

@@ -965,7 +965,7 @@ async def benchmark(
request_rate: float,
max_concurrency: Optional[int],
disable_tqdm: bool,
lora_name: str,
lora_names: List[str],
extra_request_body: Dict[str, Any],
profile: bool,
pd_seperated: bool = False,
@@ -988,6 +988,11 @@ async def benchmark(
# Warmup
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]
if lora_names != None and len(lora_names) != 0:
lora_name = lora_names[0]
else:
lora_name = None
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
@@ -1028,6 +1033,12 @@ async def benchmark(
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
if lora_names != None and len(lora_names) != 0:
idx = random.randint(0, len(lora_names) - 1)
lora_name = lora_names[idx]
else:
lora_name = None
request_func_input = RequestFuncInput(
model=model_id,
prompt=prompt,
@@ -1347,7 +1358,7 @@ def run_benchmark(args_: argparse.Namespace):
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
lora_names=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
pd_seperated=args.pd_seperated,
@@ -1366,6 +1377,13 @@ def set_ulimit(target_soft_limit=65535):
print(f"Fail to set RLIMIT_NOFILE: {e}")
class LoRAPathAction(argparse.Action):
    """argparse action that stores the supplied LoRA adapter names as a list.

    Each time the option appears on the command line the destination
    attribute is replaced (not extended) with a fresh list built from the
    values given for that occurrence.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``values`` on ``namespace`` under ``self.dest`` as a new list.

        Args:
            parser: The parser invoking this action (unused).
            namespace: The namespace object receiving the parsed value.
            values: The adapter names parsed for this option occurrence.
            option_string: The option string that triggered the action (unused).
        """
        # argparse hands over a bare string when the option is declared without
        # a list-producing ``nargs``; treat it as one name rather than
        # iterating over its characters.
        if isinstance(values, str):
            values = [values]
        # Copy so the namespace never aliases argparse's internal list.
        setattr(namespace, self.dest, list(values))
if __name__ == "__main__":
parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
@@ -1509,8 +1527,10 @@ if __name__ == "__main__":
parser.add_argument(
"--lora-name",
type=str,
nargs="*",
default=None,
help="The name of LoRA adapter",
action=LoRAPathAction,
help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
)
parser.add_argument(
"--prompt-suffix",