Add Support for API Key Authentication (#230)

This commit is contained in:
Alessio Dalla Piazza
2024-03-11 13:16:10 +01:00
committed by GitHub
parent 1b35547927
commit d5ae2ebaa2
4 changed files with 63 additions and 18 deletions

View File

@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url, auth_token=None, verify=None):
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
super().__init__()
self.support_concate_and_append = True
self.base_url = base_url
self.auth_token = auth_token
self.api_key = api_key
self.verify = verify
res = http_request(
self.base_url + "/get_model_info",
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
obj = res.json()
@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
json=data,
stream=True,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
pos = 0
@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200

View File

@@ -20,6 +20,8 @@ import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """HTTP middleware that rejects requests lacking the expected API key.

    Each incoming request must carry the configured key in the header named
    by ``API_KEY_HEADER_NAME``. Requests with a missing, empty, or
    non-matching header receive a 403 JSON response; all others are passed
    on to the next handler unchanged.
    """

    def __init__(self, app, api_key: str):
        """Wrap ``app`` and remember the key that clients must present.

        Args:
            app: The ASGI application being wrapped.
            api_key: The secret value expected in the API key header.
        """
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        """Validate the API key header and forward authorized requests.

        Args:
            request: The incoming HTTP request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            A 403 ``JSONResponse`` with ``{"detail": "Invalid API Key"}`` when
            the header is absent, empty, or wrong; otherwise the response
            produced by ``call_next(request)``.
        """
        # Imported locally so this block does not depend on a file-level
        # import being present.
        import secrets

        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
        expected = self.api_key

        # Compare in constant time so response latency does not leak how much
        # of a guessed key matched. Both sides are encoded to bytes because
        # compare_digest only accepts str arguments that are pure ASCII.
        authorized = (
            bool(api_key_header)
            and isinstance(expected, str)
            and secrets.compare_digest(
                api_key_header.encode("utf-8"), expected.encode("utf-8")
            )
        )
        if not authorized:
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        return await call_next(request)
app = FastAPI()
tokenizer_manager = None
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
assert proc_router.is_alive() and proc_detoken.is_alive()
if server_args.api_key and server_args.api_key != "":
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
def _launch_server():
uvicorn.run(
app,
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
)
def _wait_and_warmup():
headers = {}
url = server_args.url()
for _ in range(60):
time.sleep(1)
if server_args.api_key and server_args.api_key != "":
headers[API_KEY_HEADER_NAME] = server_args.api_key
for _ in range(120):
time.sleep(0.5)
try:
requests.get(url + "/get_model_info", timeout=5)
requests.get(url + "/get_model_info", timeout=5, headers=headers)
break
except requests.exceptions.RequestException as e:
pass
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
"max_new_tokens": 16,
},
},
headers=headers,
timeout=60,
)
# print(f"Warmup done. model response: {res.json()['text']}")
@@ -558,6 +585,7 @@ class Runtime:
attention_reduce_in_fp32: bool = False,
random_seed: int = 42,
log_level: str = "error",
api_key: str = "",
port: Optional[int] = None,
additional_ports: Optional[Union[List[int], int]] = None,
):
@@ -580,6 +608,7 @@ class Runtime:
attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed,
log_level=log_level,
api_key=api_key,
)
self.url = self.server_args.url()

View File

@@ -32,6 +32,7 @@ class ServerArgs:
enable_flashinfer: bool = False
disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False
api_key: str = ""
def __post_init__(self):
if self.tokenizer_path is None:
@@ -201,6 +202,12 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--api-key",
type=str,
default=ServerArgs.api_key,
help="Set API Key",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):

View File

@@ -88,23 +88,22 @@ class HttpResponse:
return self.resp.status
def http_request(url, json=None, stream=False, auth_token=None, verify=None):
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
"""A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}
# add the Authorization header if an auth token is provided
if auth_token is not None:
headers["Authorization"] = f"Bearer {auth_token}"
# add the API Key header if an API key is provided
if api_key is not None:
headers["X-API-Key"] = api_key
if stream:
if auth_token is None:
return requests.post(url, json=json, stream=True, verify=verify)
headers = {
"Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}",
}
return requests.post(
url, json=json, stream=True, headers=headers, verify=verify
)
return requests.post(url, json=json, stream=True, headers=headers, verify=verify)
else:
req = urllib.request.Request(url)
req.add_header("Content-Type", "application/json; charset=utf-8")
if auth_token is not None:
req.add_header("Authentication", f"Bearer {auth_token}")
req = urllib.request.Request(url, headers=headers)
if json is None:
data = None
else: