Fix bench_serving not recognizing OPENAI_API_KEY (#3870)
Signed-off-by: Kebe <mail@kebe7jun.com>
This commit is contained in:
@@ -71,6 +71,14 @@ def remove_prefix(text: str, prefix: str) -> str:
|
|||||||
return text[len(prefix) :] if text.startswith(prefix) else text
|
return text[len(prefix) :] if text.startswith(prefix) else text
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_headers() -> Dict[str, str]:
|
||||||
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
if api_key:
|
||||||
|
return {"Authorization": f"Bearer {api_key}"}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
# trt llm not support ignore_eos
|
# trt llm not support ignore_eos
|
||||||
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
|
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
|
||||||
async def async_request_trt_llm(
|
async def async_request_trt_llm(
|
||||||
@@ -165,7 +173,7 @@ async def async_request_openai_completions(
|
|||||||
"ignore_eos": not args.disable_ignore_eos,
|
"ignore_eos": not args.disable_ignore_eos,
|
||||||
**request_func_input.extra_request_body,
|
**request_func_input.extra_request_body,
|
||||||
}
|
}
|
||||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
headers = get_auth_headers()
|
||||||
|
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
@@ -244,7 +252,7 @@ async def async_request_truss(
|
|||||||
"ignore_eos": not args.disable_ignore_eos,
|
"ignore_eos": not args.disable_ignore_eos,
|
||||||
**request_func_input.extra_request_body,
|
**request_func_input.extra_request_body,
|
||||||
}
|
}
|
||||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
headers = get_auth_headers()
|
||||||
|
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
@@ -325,7 +333,7 @@ async def async_request_sglang_generate(
|
|||||||
"logprob_start_len": -1,
|
"logprob_start_len": -1,
|
||||||
**request_func_input.extra_request_body,
|
**request_func_input.extra_request_body,
|
||||||
}
|
}
|
||||||
headers = {}
|
headers = get_auth_headers()
|
||||||
|
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
@@ -1238,7 +1246,7 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
try:
|
try:
|
||||||
response = requests.get(model_url)
|
response = requests.get(model_url, headers=get_auth_headers())
|
||||||
model_list = response.json().get("data", [])
|
model_list = response.json().get("data", [])
|
||||||
args.model = model_list[0]["id"] if model_list else None
|
args.model = model_list[0]["id"] if model_list else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user