Add Support for API Key Authentication (#230)

This commit is contained in:
Alessio Dalla Piazza
2024-03-11 13:16:10 +01:00
committed by GitHub
parent 1b35547927
commit d5ae2ebaa2
4 changed files with 63 additions and 18 deletions

View File

@@ -20,6 +20,8 @@ import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Name of the HTTP header that carries the API key: the validator middleware
# reads it from incoming requests, and the server's own warmup requests send it.
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Reject every request that does not present the configured API key.

    The key is read from the ``API_KEY_HEADER_NAME`` request header and
    compared with the key supplied at construction time. A request whose
    header is missing, empty, or different is answered with a 403 JSON
    response and is never forwarded to the wrapped application.
    """

    def __init__(self, app, api_key: str):
        """Wrap ``app`` and remember the expected key.

        Args:
            app: The application this middleware is installed on.
            api_key: The secret each request must send in the API key header.
        """
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        """Validate the API key header, then forward the request.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that passes the request on to the
                next handler and yields its response.

        Returns:
            A 403 ``JSONResponse`` with ``{"detail": "Invalid API Key"}`` when
            the key is absent or wrong; otherwise whatever ``call_next``
            returns.
        """
        # Imported here so the block does not depend on a file-level import.
        import hmac

        # Extract the API key from the request headers.
        provided_key = request.headers.get(API_KEY_HEADER_NAME)

        # Use a constant-time comparison so response timing does not reveal
        # how many leading characters of a guessed key were correct.
        # compare_digest only accepts ASCII-only str, and a client-supplied
        # header value may contain other characters, so compare encoded bytes.
        if not provided_key or not hmac.compare_digest(
            provided_key.encode("utf-8"),
            self.api_key.encode("utf-8"),
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )

        return await call_next(request)
app = FastAPI()
tokenizer_manager = None
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
assert proc_router.is_alive() and proc_detoken.is_alive()
if server_args.api_key and server_args.api_key != "":
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
def _launch_server():
uvicorn.run(
app,
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
)
def _wait_and_warmup():
headers = {}
url = server_args.url()
for _ in range(60):
time.sleep(1)
if server_args.api_key and server_args.api_key != "":
headers[API_KEY_HEADER_NAME] = server_args.api_key
for _ in range(120):
time.sleep(0.5)
try:
requests.get(url + "/get_model_info", timeout=5)
requests.get(url + "/get_model_info", timeout=5, headers=headers)
break
except requests.exceptions.RequestException as e:
pass
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
"max_new_tokens": 16,
},
},
headers=headers,
timeout=60,
)
# print(f"Warmup done. model response: {res.json()['text']}")
@@ -558,6 +585,7 @@ class Runtime:
attention_reduce_in_fp32: bool = False,
random_seed: int = 42,
log_level: str = "error",
api_key: str = "",
port: Optional[int] = None,
additional_ports: Optional[Union[List[int], int]] = None,
):
@@ -580,6 +608,7 @@ class Runtime:
attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed,
log_level=log_level,
api_key=api_key,
)
self.url = self.server_args.url()

View File

@@ -32,6 +32,7 @@ class ServerArgs:
enable_flashinfer: bool = False
disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False
api_key: str = ""
def __post_init__(self):
if self.tokenizer_path is None:
@@ -201,6 +202,12 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--api-key",
type=str,
default=ServerArgs.api_key,
help="Set API Key",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):