diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py
index 1c0d44540..fc5774ef5 100644
--- a/python/sglang/backend/runtime_endpoint.py
+++ b/python/sglang/backend/runtime_endpoint.py
@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
 
 
 class RuntimeEndpoint(BaseBackend):
-    def __init__(self, base_url, auth_token=None, verify=None):
+    def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
         super().__init__()
         self.support_concate_and_append = True
 
         self.base_url = base_url
         self.auth_token = auth_token
+        self.api_key = api_key
         self.verify = verify
 
         res = http_request(
             self.base_url + "/get_model_info",
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         assert res.status_code == 200
@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         assert res.status_code == 200
@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         assert res.status_code == 200
@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         assert res.status_code == 200
@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         obj = res.json()
@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
             json=data,
             stream=True,
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         pos = 0
@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         assert res.status_code == 200
@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         assert res.status_code == 200
@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
             auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         assert res.status_code == 200
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 95b7b439c..3d853bd92 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -20,6 +20,8 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, HTTPException, Request
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.responses import JSONResponse
 from fastapi.responses import Response, StreamingResponse
 from pydantic import BaseModel
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
@@ -57,6 +59,25 @@ from sglang.srt.utils import handle_port_init
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+API_KEY_HEADER_NAME = "X-API-Key"
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    """Reject requests whose X-API-Key header does not match the configured key."""
+
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request: Request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
 
 app = FastAPI()
 tokenizer_manager = None
@@ -482,6 +503,9 @@ def launch_server(server_args, pipe_finish_writer):
 
     assert proc_router.is_alive() and proc_detoken.is_alive()
 
+    if server_args.api_key:
+        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+
     def _launch_server():
         uvicorn.run(
             app,
@@ -493,11 +517,15 @@ def launch_server(server_args, pipe_finish_writer):
         )
 
     def _wait_and_warmup():
+        headers = {}
         url = server_args.url()
-        for _ in range(60):
-            time.sleep(1)
+        if server_args.api_key:
+            headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+        for _ in range(120):
+            time.sleep(0.5)
             try:
-                requests.get(url + "/get_model_info", timeout=5)
+                requests.get(url + "/get_model_info", timeout=5, headers=headers)
                 break
             except requests.exceptions.RequestException as e:
                 pass
@@ -520,6 +548,7 @@ def launch_server(server_args, pipe_finish_writer):
                         "max_new_tokens": 16,
                     },
                 },
+                headers=headers,
                 timeout=60,
             )
             # print(f"Warmup done. model response: {res.json()['text']}")
@@ -558,6 +587,7 @@ class Runtime:
         attention_reduce_in_fp32: bool = False,
         random_seed: int = 42,
         log_level: str = "error",
+        api_key: str = "",
         port: Optional[int] = None,
         additional_ports: Optional[Union[List[int], int]] = None,
     ):
@@ -580,6 +610,7 @@ class Runtime:
             attention_reduce_in_fp32=attention_reduce_in_fp32,
             random_seed=random_seed,
             log_level=log_level,
+            api_key=api_key,
         )
 
         self.url = self.server_args.url()
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index f236a9ae1..b59fd1c0c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -32,6 +32,7 @@ class ServerArgs:
     enable_flashinfer: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
+    api_key: str = ""
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -201,6 +202,12 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API Key",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 143d4b8f7..e7638e6a4 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -88,23 +88,23 @@ class HttpResponse:
         return self.resp.status
 
 
-def http_request(url, json=None, stream=False, auth_token=None, verify=None):
+def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
     """A faster version of requests.post with low-level urllib API."""
+    headers = {"Content-Type": "application/json; charset=utf-8"}
+
+    # add the Authorization header if an auth token is provided
+    if auth_token is not None:
+        headers["Authorization"] = f"Bearer {auth_token}"
+
+    # add the API Key header if an API key is provided
+    if api_key is not None:
+        headers["X-API-Key"] = api_key
+
     if stream:
-        if auth_token is None:
-            return requests.post(url, json=json, stream=True, verify=verify)
-        headers = {
-            "Content-Type": "application/json",
-            "Authentication": f"Bearer {auth_token}",
-        }
-        return requests.post(
-            url, json=json, stream=True, headers=headers, verify=verify
-        )
+        # keep forwarding the caller's TLS verification setting, as the removed code did
+        return requests.post(url, json=json, stream=True, headers=headers, verify=verify)
     else:
-        req = urllib.request.Request(url)
-        req.add_header("Content-Type", "application/json; charset=utf-8")
-        if auth_token is not None:
-            req.add_header("Authentication", f"Bearer {auth_token}")
+        req = urllib.request.Request(url, headers=headers)
         if json is None:
             data = None
         else: