Support LoRA in Completion API (#2243)
Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
@@ -51,6 +51,7 @@ class RequestFuncInput:
|
|||||||
prompt_len: int
|
prompt_len: int
|
||||||
output_len: int
|
output_len: int
|
||||||
model: str
|
model: str
|
||||||
|
lora_name: str
|
||||||
extra_request_body: Dict[str, Any]
|
extra_request_body: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@@ -162,6 +163,7 @@ async def async_request_openai_completions(
|
|||||||
"max_tokens": request_func_input.output_len,
|
"max_tokens": request_func_input.output_len,
|
||||||
"stream": not args.disable_stream,
|
"stream": not args.disable_stream,
|
||||||
"ignore_eos": not args.disable_ignore_eos,
|
"ignore_eos": not args.disable_ignore_eos,
|
||||||
|
"lora_path": request_func_input.lora_name,
|
||||||
**request_func_input.extra_request_body,
|
**request_func_input.extra_request_body,
|
||||||
}
|
}
|
||||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
@@ -319,6 +321,7 @@ async def async_request_sglang_generate(
|
|||||||
"ignore_eos": not args.disable_ignore_eos,
|
"ignore_eos": not args.disable_ignore_eos,
|
||||||
},
|
},
|
||||||
"stream": not args.disable_stream,
|
"stream": not args.disable_stream,
|
||||||
|
"lora_path": request_func_input.lora_name,
|
||||||
**request_func_input.extra_request_body,
|
**request_func_input.extra_request_body,
|
||||||
}
|
}
|
||||||
headers = {}
|
headers = {}
|
||||||
@@ -884,6 +887,7 @@ async def benchmark(
|
|||||||
request_rate: float,
|
request_rate: float,
|
||||||
max_concurrency: Optional[int],
|
max_concurrency: Optional[int],
|
||||||
disable_tqdm: bool,
|
disable_tqdm: bool,
|
||||||
|
lora_name: str,
|
||||||
extra_request_body: Dict[str, Any],
|
extra_request_body: Dict[str, Any],
|
||||||
profile: bool,
|
profile: bool,
|
||||||
):
|
):
|
||||||
@@ -909,6 +913,7 @@ async def benchmark(
|
|||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
prompt_len=test_prompt_len,
|
prompt_len=test_prompt_len,
|
||||||
output_len=test_output_len,
|
output_len=test_output_len,
|
||||||
|
lora_name=lora_name,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
)
|
)
|
||||||
test_output = await request_func(request_func_input=test_input)
|
test_output = await request_func(request_func_input=test_input)
|
||||||
@@ -942,6 +947,7 @@ async def benchmark(
|
|||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
output_len=output_len,
|
output_len=output_len,
|
||||||
|
lora_name=lora_name,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
)
|
)
|
||||||
tasks.append(
|
tasks.append(
|
||||||
@@ -1247,6 +1253,7 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
request_rate=args.request_rate,
|
request_rate=args.request_rate,
|
||||||
max_concurrency=args.max_concurrency,
|
max_concurrency=args.max_concurrency,
|
||||||
disable_tqdm=args.disable_tqdm,
|
disable_tqdm=args.disable_tqdm,
|
||||||
|
lora_name=args.lora_name,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
profile=args.profile,
|
profile=args.profile,
|
||||||
)
|
)
|
||||||
@@ -1267,6 +1274,7 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
request_rate=rate,
|
request_rate=rate,
|
||||||
max_concurrency=args.max_concurrency,
|
max_concurrency=args.max_concurrency,
|
||||||
disable_tqdm=args.disable_tqdm,
|
disable_tqdm=args.disable_tqdm,
|
||||||
|
lora_name=args.lora_name,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
profile=args.profile,
|
profile=args.profile,
|
||||||
)
|
)
|
||||||
@@ -1451,5 +1459,11 @@ if __name__ == "__main__":
|
|||||||
help="Use Torch Profiler. The endpoint must be launched with "
|
help="Use Torch Profiler. The endpoint must be launched with "
|
||||||
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The name of LoRA adapter",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
run_benchmark(args)
|
run_benchmark(args)
|
||||||
|
|||||||
@@ -486,6 +486,7 @@ def v1_generate_request(
|
|||||||
return_logprobs = []
|
return_logprobs = []
|
||||||
logprob_start_lens = []
|
logprob_start_lens = []
|
||||||
top_logprobs_nums = []
|
top_logprobs_nums = []
|
||||||
|
lora_paths = []
|
||||||
|
|
||||||
for request in all_requests:
|
for request in all_requests:
|
||||||
# NOTE: with openai API, the prompt's logprobs are always not computed
|
# NOTE: with openai API, the prompt's logprobs are always not computed
|
||||||
@@ -496,6 +497,7 @@ def v1_generate_request(
|
|||||||
)
|
)
|
||||||
|
|
||||||
prompts.append(request.prompt)
|
prompts.append(request.prompt)
|
||||||
|
lora_paths.append(request.lora_path)
|
||||||
if request.echo and request.logprobs:
|
if request.echo and request.logprobs:
|
||||||
current_logprob_start_len = 0
|
current_logprob_start_len = 0
|
||||||
else:
|
else:
|
||||||
@@ -534,6 +536,7 @@ def v1_generate_request(
|
|||||||
return_logprobs = return_logprobs[0]
|
return_logprobs = return_logprobs[0]
|
||||||
logprob_start_lens = logprob_start_lens[0]
|
logprob_start_lens = logprob_start_lens[0]
|
||||||
top_logprobs_nums = top_logprobs_nums[0]
|
top_logprobs_nums = top_logprobs_nums[0]
|
||||||
|
lora_paths = lora_paths[0]
|
||||||
else:
|
else:
|
||||||
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
|
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
|
||||||
prompt_kwargs = {"text": prompts}
|
prompt_kwargs = {"text": prompts}
|
||||||
@@ -549,6 +552,7 @@ def v1_generate_request(
|
|||||||
return_text_in_logprobs=True,
|
return_text_in_logprobs=True,
|
||||||
stream=all_requests[0].stream,
|
stream=all_requests[0].stream,
|
||||||
rid=request_ids,
|
rid=request_ids,
|
||||||
|
lora_path=lora_paths,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ class CompletionRequest(BaseModel):
|
|||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
top_p: float = 1.0
|
top_p: float = 1.0
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||||
json_schema: Optional[str] = None
|
json_schema: Optional[str] = None
|
||||||
|
|||||||
@@ -567,6 +567,7 @@ def run_bench_serving(
|
|||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
disable_stream=disable_stream,
|
disable_stream=disable_stream,
|
||||||
disable_ignore_eos=False,
|
disable_ignore_eos=False,
|
||||||
|
lora_name=None,
|
||||||
extra_request_body=None,
|
extra_request_body=None,
|
||||||
profile=None,
|
profile=None,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user