Add Support for API Key Authentication (#230)
This commit is contained in:
committed by
GitHub
parent
1b35547927
commit
d5ae2ebaa2
@@ -20,6 +20,8 @@ import requests
|
||||
import uvicorn
|
||||
import uvloop
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
API_KEY_HEADER_NAME = "X-API-Key"
|
||||
|
||||
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(self, app, api_key: str):
|
||||
super().__init__(app)
|
||||
self.api_key = api_key
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# extract API key from the request headers
|
||||
api_key_header = request.headers.get(API_KEY_HEADER_NAME)
|
||||
if not api_key_header or api_key_header != self.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Invalid API Key"},
|
||||
)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
app = FastAPI()
|
||||
tokenizer_manager = None
|
||||
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
|
||||
assert proc_router.is_alive() and proc_detoken.is_alive()
|
||||
|
||||
if server_args.api_key and server_args.api_key != "":
|
||||
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
|
||||
|
||||
def _launch_server():
|
||||
uvicorn.run(
|
||||
app,
|
||||
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
)
|
||||
|
||||
def _wait_and_warmup():
|
||||
headers = {}
|
||||
url = server_args.url()
|
||||
for _ in range(60):
|
||||
time.sleep(1)
|
||||
if server_args.api_key and server_args.api_key != "":
|
||||
headers[API_KEY_HEADER_NAME] = server_args.api_key
|
||||
|
||||
for _ in range(120):
|
||||
time.sleep(0.5)
|
||||
try:
|
||||
requests.get(url + "/get_model_info", timeout=5)
|
||||
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
||||
break
|
||||
except requests.exceptions.RequestException as e:
|
||||
pass
|
||||
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
"max_new_tokens": 16,
|
||||
},
|
||||
},
|
||||
headers=headers,
|
||||
timeout=60,
|
||||
)
|
||||
# print(f"Warmup done. model response: {res.json()['text']}")
|
||||
@@ -558,6 +585,7 @@ class Runtime:
|
||||
attention_reduce_in_fp32: bool = False,
|
||||
random_seed: int = 42,
|
||||
log_level: str = "error",
|
||||
api_key: str = "",
|
||||
port: Optional[int] = None,
|
||||
additional_ports: Optional[Union[List[int], int]] = None,
|
||||
):
|
||||
@@ -580,6 +608,7 @@ class Runtime:
|
||||
attention_reduce_in_fp32=attention_reduce_in_fp32,
|
||||
random_seed=random_seed,
|
||||
log_level=log_level,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
self.url = self.server_args.url()
|
||||
|
||||
@@ -32,6 +32,7 @@ class ServerArgs:
|
||||
enable_flashinfer: bool = False
|
||||
disable_regex_jump_forward: bool = False
|
||||
disable_disk_cache: bool = False
|
||||
api_key: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
@@ -201,6 +202,12 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=ServerArgs.api_key,
|
||||
help="Set API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user