Support LoRA in Completion API (#2243)
Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
@@ -51,6 +51,7 @@ class RequestFuncInput:
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
lora_name: str
|
||||
extra_request_body: Dict[str, Any]
|
||||
|
||||
|
||||
@@ -162,6 +163,7 @@ async def async_request_openai_completions(
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": not args.disable_stream,
|
||||
"ignore_eos": not args.disable_ignore_eos,
|
||||
"lora_path": request_func_input.lora_name,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
@@ -319,6 +321,7 @@ async def async_request_sglang_generate(
|
||||
"ignore_eos": not args.disable_ignore_eos,
|
||||
},
|
||||
"stream": not args.disable_stream,
|
||||
"lora_path": request_func_input.lora_name,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
headers = {}
|
||||
@@ -884,6 +887,7 @@ async def benchmark(
|
||||
request_rate: float,
|
||||
max_concurrency: Optional[int],
|
||||
disable_tqdm: bool,
|
||||
lora_name: str,
|
||||
extra_request_body: Dict[str, Any],
|
||||
profile: bool,
|
||||
):
|
||||
@@ -909,6 +913,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
lora_name=lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
@@ -942,6 +947,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
lora_name=lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
tasks.append(
|
||||
@@ -1247,6 +1253,7 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
request_rate=args.request_rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
)
|
||||
@@ -1267,6 +1274,7 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
request_rate=rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
)
|
||||
@@ -1451,5 +1459,11 @@ if __name__ == "__main__":
|
||||
help="Use Torch Profiler. The endpoint must be launched with "
|
||||
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of LoRA adapter",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
run_benchmark(args)
|
||||
|
||||
Reference in New Issue
Block a user