Add Support for API Key Authentication (#230)
This commit is contained in:
committed by
GitHub
parent
1b35547927
commit
d5ae2ebaa2
@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
|
||||
def __init__(self, base_url, auth_token=None, verify=None):
|
||||
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
|
||||
super().__init__()
|
||||
self.support_concate_and_append = True
|
||||
|
||||
self.base_url = base_url
|
||||
self.auth_token = auth_token
|
||||
self.api_key = api_key
|
||||
self.verify = verify
|
||||
|
||||
res = http_request(
|
||||
self.base_url + "/get_model_info",
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
obj = res.json()
|
||||
@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
json=data,
|
||||
stream=True,
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
pos = 0
|
||||
@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/concate_and_append_request",
|
||||
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
@@ -20,6 +20,8 @@ import requests
|
||||
import uvicorn
|
||||
import uvloop
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
API_KEY_HEADER_NAME = "X-API-Key"
|
||||
|
||||
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Reject requests that do not present the configured API key.

    The key is read from the request header named by ``API_KEY_HEADER_NAME``
    and compared with the key supplied when the middleware was constructed.
    """

    def __init__(self, app, api_key: str):
        """Wrap ``app`` and remember the key that requests must present.

        Args:
            app: The wrapped application, forwarded to ``BaseHTTPMiddleware``.
            api_key: The expected API key value.
        """
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        """Validate the API key header, then forward the request.

        Args:
            request: The incoming request whose headers are inspected.
            call_next: Awaitable callable that produces the downstream response.

        Returns:
            A ``JSONResponse`` with status 403 and body
            ``{"detail": "Invalid API Key"}`` when the header is missing,
            empty, or does not match; otherwise the result of
            ``call_next(request)``.
        """
        # Function-scope import keeps this block self-contained.
        import hmac

        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)

        # Compare in constant time so the response latency does not leak how
        # much of a guessed key matched. Both sides are encoded to bytes
        # because compare_digest rejects non-ASCII ``str`` arguments.
        if not api_key_header or not hmac.compare_digest(
            api_key_header.encode("utf-8"),
            self.api_key.encode("utf-8"),
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )

        return await call_next(request)
|
||||
|
||||
app = FastAPI()
|
||||
tokenizer_manager = None
|
||||
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
|
||||
assert proc_router.is_alive() and proc_detoken.is_alive()
|
||||
|
||||
if server_args.api_key and server_args.api_key != "":
|
||||
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
|
||||
|
||||
def _launch_server():
|
||||
uvicorn.run(
|
||||
app,
|
||||
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
)
|
||||
|
||||
def _wait_and_warmup():
|
||||
headers = {}
|
||||
url = server_args.url()
|
||||
for _ in range(60):
|
||||
time.sleep(1)
|
||||
if server_args.api_key and server_args.api_key != "":
|
||||
headers[API_KEY_HEADER_NAME] = server_args.api_key
|
||||
|
||||
for _ in range(120):
|
||||
time.sleep(0.5)
|
||||
try:
|
||||
requests.get(url + "/get_model_info", timeout=5)
|
||||
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
||||
break
|
||||
except requests.exceptions.RequestException as e:
|
||||
pass
|
||||
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
"max_new_tokens": 16,
|
||||
},
|
||||
},
|
||||
headers=headers,
|
||||
timeout=60,
|
||||
)
|
||||
# print(f"Warmup done. model response: {res.json()['text']}")
|
||||
@@ -558,6 +585,7 @@ class Runtime:
|
||||
attention_reduce_in_fp32: bool = False,
|
||||
random_seed: int = 42,
|
||||
log_level: str = "error",
|
||||
api_key: str = "",
|
||||
port: Optional[int] = None,
|
||||
additional_ports: Optional[Union[List[int], int]] = None,
|
||||
):
|
||||
@@ -580,6 +608,7 @@ class Runtime:
|
||||
attention_reduce_in_fp32=attention_reduce_in_fp32,
|
||||
random_seed=random_seed,
|
||||
log_level=log_level,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
self.url = self.server_args.url()
|
||||
|
||||
@@ -32,6 +32,7 @@ class ServerArgs:
|
||||
enable_flashinfer: bool = False
|
||||
disable_regex_jump_forward: bool = False
|
||||
disable_disk_cache: bool = False
|
||||
api_key: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
@@ -201,6 +202,12 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=ServerArgs.api_key,
|
||||
help="Set API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
@@ -88,23 +88,22 @@ class HttpResponse:
|
||||
return self.resp.status
|
||||
|
||||
|
||||
def http_request(url, json=None, stream=False, auth_token=None, verify=None):
|
||||
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
|
||||
"""A faster version of requests.post with low-level urllib API."""
|
||||
headers = {"Content-Type": "application/json; charset=utf-8"}
|
||||
|
||||
# add the Authorization header if an auth token is provided
|
||||
if auth_token is not None:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
|
||||
# add the API Key header if an API key is provided
|
||||
if api_key is not None:
|
||||
headers["X-API-Key"] = api_key
|
||||
|
||||
if stream:
|
||||
if auth_token is None:
|
||||
return requests.post(url, json=json, stream=True, verify=verify)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authentication": f"Bearer {auth_token}",
|
||||
}
|
||||
return requests.post(
|
||||
url, json=json, stream=True, headers=headers, verify=verify
|
||||
)
|
||||
return requests.post(url, json=json, stream=True, headers=headers)
|
||||
else:
|
||||
req = urllib.request.Request(url)
|
||||
req.add_header("Content-Type", "application/json; charset=utf-8")
|
||||
if auth_token is not None:
|
||||
req.add_header("Authentication", f"Bearer {auth_token}")
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
if json is None:
|
||||
data = None
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user