diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index fc987626d..3eca72de4 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -51,6 +51,7 @@ class RequestFuncInput: prompt_len: int output_len: int model: str + lora_name: str extra_request_body: Dict[str, Any] @@ -162,6 +163,7 @@ async def async_request_openai_completions( "max_tokens": request_func_input.output_len, "stream": not args.disable_stream, "ignore_eos": not args.disable_ignore_eos, + "lora_path": request_func_input.lora_name, **request_func_input.extra_request_body, } headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} @@ -319,6 +321,7 @@ async def async_request_sglang_generate( "ignore_eos": not args.disable_ignore_eos, }, "stream": not args.disable_stream, + "lora_path": request_func_input.lora_name, **request_func_input.extra_request_body, } headers = {} @@ -884,6 +887,7 @@ async def benchmark( request_rate: float, max_concurrency: Optional[int], disable_tqdm: bool, + lora_name: str, extra_request_body: Dict[str, Any], profile: bool, ): @@ -909,6 +913,7 @@ async def benchmark( api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, + lora_name=lora_name, extra_request_body=extra_request_body, ) test_output = await request_func(request_func_input=test_input) @@ -942,6 +947,7 @@ async def benchmark( api_url=api_url, prompt_len=prompt_len, output_len=output_len, + lora_name=lora_name, extra_request_body=extra_request_body, ) tasks.append( @@ -1247,6 +1253,7 @@ def run_benchmark(args_: argparse.Namespace): request_rate=args.request_rate, max_concurrency=args.max_concurrency, disable_tqdm=args.disable_tqdm, + lora_name=args.lora_name, extra_request_body=extra_request_body, profile=args.profile, ) @@ -1267,6 +1274,7 @@ def run_benchmark(args_: argparse.Namespace): request_rate=rate, max_concurrency=args.max_concurrency, disable_tqdm=args.disable_tqdm, + lora_name=args.lora_name, extra_request_body=extra_request_body, profile=args.profile, ) @@ -1451,5 +1459,11 @@ if __name__ == "__main__": help="Use Torch Profiler. The endpoint must be launched with " "SGLANG_TORCH_PROFILER_DIR to enable profiler.", ) + parser.add_argument( + "--lora-name", + type=str, + default=None, + help="The name of LoRA adapter", + ) args = parser.parse_args() run_benchmark(args) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index e43a9f5b5..7dcfa6e13 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -486,6 +486,7 @@ def v1_generate_request( return_logprobs = [] logprob_start_lens = [] top_logprobs_nums = [] + lora_paths = [] for request in all_requests: # NOTE: with openai API, the prompt's logprobs are always not computed @@ -496,6 +497,7 @@ def v1_generate_request( ) prompts.append(request.prompt) + lora_paths.append(request.lora_path) if request.echo and request.logprobs: current_logprob_start_len = 0 else: @@ -534,6 +536,7 @@ def v1_generate_request( return_logprobs = return_logprobs[0] logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] + lora_paths = lora_paths[0] else: if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): prompt_kwargs = {"text": prompts} @@ -549,6 +552,7 @@ def v1_generate_request( return_text_in_logprobs=True, stream=all_requests[0].stream, rid=request_ids, + lora_path=lora_paths, ) return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 301897abb..7c88ad533 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -166,6 +166,7 @@ class CompletionRequest(BaseModel): temperature: float = 1.0 top_p: float = 1.0 user: Optional[str] = None + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None # Extra parameters for SRT backend only and will be ignored by OpenAI models. json_schema: Optional[str] = None diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index a1646fb5f..355294602 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -567,6 +567,7 @@ def run_bench_serving( disable_tqdm=False, disable_stream=disable_stream, disable_ignore_eos=False, + lora_name=None, extra_request_body=None, profile=None, )