Support LoRA in Completion API (#2243)

Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
bjmsong
2024-11-30 08:13:38 +08:00
committed by GitHub
parent 94e167ea5a
commit 01017d4c20
4 changed files with 20 additions and 0 deletions

View File

@@ -51,6 +51,7 @@ class RequestFuncInput:
prompt_len: int
output_len: int
model: str
lora_name: str
extra_request_body: Dict[str, Any]
@@ -162,6 +163,7 @@ async def async_request_openai_completions(
"max_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
"ignore_eos": not args.disable_ignore_eos,
"lora_path": request_func_input.lora_name,
**request_func_input.extra_request_body,
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
@@ -319,6 +321,7 @@ async def async_request_sglang_generate(
"ignore_eos": not args.disable_ignore_eos,
},
"stream": not args.disable_stream,
"lora_path": request_func_input.lora_name,
**request_func_input.extra_request_body,
}
headers = {}
@@ -884,6 +887,7 @@ async def benchmark(
request_rate: float,
max_concurrency: Optional[int],
disable_tqdm: bool,
lora_name: str,
extra_request_body: Dict[str, Any],
profile: bool,
):
@@ -909,6 +913,7 @@ async def benchmark(
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
lora_name=lora_name,
extra_request_body=extra_request_body,
)
test_output = await request_func(request_func_input=test_input)
@@ -942,6 +947,7 @@ async def benchmark(
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
lora_name=lora_name,
extra_request_body=extra_request_body,
)
tasks.append(
@@ -1247,6 +1253,7 @@ def run_benchmark(args_: argparse.Namespace):
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
@@ -1267,6 +1274,7 @@ def run_benchmark(args_: argparse.Namespace):
request_rate=rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
@@ -1451,5 +1459,11 @@ if __name__ == "__main__":
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
# Optional LoRA adapter selector. The parsed value (args.lora_name) is handed
# to benchmark() and copied into each request payload under the "lora_path"
# key; it stays None when the flag is not given.
parser.add_argument(
"--lora-name",
type=str,
default=None,
help="The name of LoRA adapter",
)
args = parser.parse_args()
run_benchmark(args)