diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 10ce965be..41e1a6109 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -71,6 +71,14 @@ def remove_prefix(text: str, prefix: str) -> str: return text[len(prefix) :] if text.startswith(prefix) else text +def get_auth_headers() -> Dict[str, str]: + api_key = os.environ.get("OPENAI_API_KEY") + if api_key: + return {"Authorization": f"Bearer {api_key}"} + else: + return {} + + # trt llm not support ignore_eos # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 async def async_request_trt_llm( @@ -165,7 +173,7 @@ async def async_request_openai_completions( "ignore_eos": not args.disable_ignore_eos, **request_func_input.extra_request_body, } - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + headers = get_auth_headers() output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -244,7 +252,7 @@ async def async_request_truss( "ignore_eos": not args.disable_ignore_eos, **request_func_input.extra_request_body, } - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + headers = get_auth_headers() output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -325,7 +333,7 @@ async def async_request_sglang_generate( "logprob_start_len": -1, **request_func_input.extra_request_body, } - headers = {} + headers = get_auth_headers() output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -1238,7 +1246,7 @@ def run_benchmark(args_: argparse.Namespace): ) sys.exit(1) try: - response = requests.get(model_url) + response = requests.get(model_url, headers=get_auth_headers()) model_list = response.json().get("data", []) args.model = model_list[0]["id"] if model_list else None except Exception as e: