Support LoRA in Completion API (#2243)
Co-authored-by: root <bjmsong@126.com>
@@ -51,6 +51,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    lora_name: str
     extra_request_body: Dict[str, Any]
 
 
@@ -162,6 +163,7 @@ async def async_request_openai_completions(
             "max_tokens": request_func_input.output_len,
             "stream": not args.disable_stream,
             "ignore_eos": not args.disable_ignore_eos,
+            "lora_path": request_func_input.lora_name,
             **request_func_input.extra_request_body,
         }
         headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
@@ -319,6 +321,7 @@ async def async_request_sglang_generate(
                 "ignore_eos": not args.disable_ignore_eos,
             },
             "stream": not args.disable_stream,
+            "lora_path": request_func_input.lora_name,
             **request_func_input.extra_request_body,
         }
         headers = {}
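
For illustration, a minimal sketch of the payloads the two client functions now build. The server URL, model name, and adapter name are placeholder assumptions; the field names and their placement come from the diff above.

    import requests

    base_url = "http://127.0.0.1:30000"  # hypothetical server address

    # OpenAI-compatible completions endpoint: "lora_path" sits alongside the
    # standard fields and names the LoRA adapter to apply.
    openai_payload = {
        "model": "meta-llama/Llama-2-7b-hf",  # placeholder model
        "prompt": "Hello",
        "max_tokens": 32,
        "stream": False,
        "ignore_eos": True,
        "lora_path": "my-adapter",  # maps to request_func_input.lora_name
    }
    print(requests.post(f"{base_url}/v1/completions", json=openai_payload).json())

    # Native generate endpoint: the same field goes at the top level,
    # next to "stream", not inside "sampling_params".
    native_payload = {
        "text": "Hello",
        "sampling_params": {"max_new_tokens": 32, "ignore_eos": True},
        "stream": False,
        "lora_path": "my-adapter",
    }
    print(requests.post(f"{base_url}/generate", json=native_payload).json())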
@@ -884,6 +887,7 @@ async def benchmark(
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
+    lora_name: str,
     extra_request_body: Dict[str, Any],
     profile: bool,
 ):
@@ -909,6 +913,7 @@ async def benchmark(
         api_url=api_url,
         prompt_len=test_prompt_len,
         output_len=test_output_len,
+        lora_name=lora_name,
         extra_request_body=extra_request_body,
     )
     test_output = await request_func(request_func_input=test_input)
@@ -942,6 +947,7 @@ async def benchmark(
             api_url=api_url,
             prompt_len=prompt_len,
             output_len=output_len,
+            lora_name=lora_name,
             extra_request_body=extra_request_body,
         )
         tasks.append(
@@ -1247,6 +1253,7 @@ def run_benchmark(args_: argparse.Namespace):
                 request_rate=args.request_rate,
                 max_concurrency=args.max_concurrency,
                 disable_tqdm=args.disable_tqdm,
+                lora_name=args.lora_name,
                 extra_request_body=extra_request_body,
                 profile=args.profile,
             )
@@ -1267,6 +1274,7 @@ def run_benchmark(args_: argparse.Namespace):
                 request_rate=rate,
                 max_concurrency=args.max_concurrency,
                 disable_tqdm=args.disable_tqdm,
+                lora_name=args.lora_name,
                 extra_request_body=extra_request_body,
                 profile=args.profile,
             )
@@ -1451,5 +1459,11 @@ if __name__ == "__main__":
         help="Use Torch Profiler. The endpoint must be launched with "
         "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
     )
+    parser.add_argument(
+        "--lora-name",
+        type=str,
+        default=None,
+        help="The name of LoRA adapter",
+    )
     args = parser.parse_args()
     run_benchmark(args)
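
With the new flag in place, a LoRA benchmark run might look like the following. The module path, backend, and adapter name are assumptions for illustration; the adapter name must match one registered when the server was launched.

    python3 -m sglang.bench_serving --backend sglang --num-prompts 100 --lora-name my-adapter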
@@ -486,6 +486,7 @@ def v1_generate_request(
     return_logprobs = []
     logprob_start_lens = []
     top_logprobs_nums = []
+    lora_paths = []
 
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -496,6 +497,7 @@ def v1_generate_request(
         )
 
         prompts.append(request.prompt)
+        lora_paths.append(request.lora_path)
         if request.echo and request.logprobs:
             current_logprob_start_len = 0
         else:
@@ -534,6 +536,7 @@ def v1_generate_request(
         return_logprobs = return_logprobs[0]
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
+        lora_paths = lora_paths[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -549,6 +552,7 @@ def v1_generate_request(
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
         rid=request_ids,
+        lora_path=lora_paths,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
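
As with the other per-request fields, lora_paths collapses to a scalar for a single request and stays a list for a batch, so the adapted request carries either one adapter name (or None) or one entry per prompt. A standalone toy illustration of that collapse, not the adapter code itself:

    def collapse(per_request_values):
        # Mirrors the [0]-indexing above: a batch of one becomes a scalar,
        # a larger batch is forwarded as-is.
        return per_request_values[0] if len(per_request_values) == 1 else per_request_values

    assert collapse(["my-adapter"]) == "my-adapter"
    assert collapse(["adapter-a", None, "adapter-b"]) == ["adapter-a", None, "adapter-b"]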
@@ -166,6 +166,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     json_schema: Optional[str] = None
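
Because lora_path accepts either a single value or a list, a client can pin a different adapter to each prompt of a batched completion. A sketch using the openai Python client's extra_body passthrough; the base URL, model, and adapter names are assumptions:

    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
    resp = client.completions.create(
        model="meta-llama/Llama-2-7b-hf",  # placeholder model
        prompt=["Hello", "Bonjour"],
        max_tokens=16,
        extra_body={"lora_path": ["adapter-a", "adapter-b"]},  # one adapter per prompt
    )
    print([choice.text for choice in resp.choices])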
@@ -567,6 +567,7 @@ def run_bench_serving(
         disable_tqdm=False,
         disable_stream=disable_stream,
         disable_ignore_eos=False,
+        lora_name=None,
         extra_request_body=None,
         profile=None,
     )