Add Support for API Key Authentication (#230)
This commit is contained in:
committed by
GitHub
parent
1b35547927
commit
d5ae2ebaa2
@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
|
|||||||
|
|
||||||
|
|
||||||
class RuntimeEndpoint(BaseBackend):
|
class RuntimeEndpoint(BaseBackend):
|
||||||
def __init__(self, base_url, auth_token=None, verify=None):
|
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.support_concate_and_append = True
|
self.support_concate_and_append = True
|
||||||
|
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.auth_token = auth_token
|
self.auth_token = auth_token
|
||||||
|
self.api_key = api_key
|
||||||
self.verify = verify
|
self.verify = verify
|
||||||
|
|
||||||
res = http_request(
|
res = http_request(
|
||||||
self.base_url + "/get_model_info",
|
self.base_url + "/get_model_info",
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json=data,
|
json=data,
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json=data,
|
json=data,
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
obj = res.json()
|
obj = res.json()
|
||||||
@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
json=data,
|
json=data,
|
||||||
stream=True,
|
stream=True,
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
pos = 0
|
pos = 0
|
||||||
@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json=data,
|
json=data,
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json=data,
|
json=data,
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/concate_and_append_request",
|
self.base_url + "/concate_and_append_request",
|
||||||
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ import requests
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
|
|||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
API_KEY_HEADER_NAME = "X-API-Key"
|
||||||
|
|
||||||
|
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||||
|
def __init__(self, app, api_key: str):
|
||||||
|
super().__init__(app)
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
# extract API key from the request headers
|
||||||
|
api_key_header = request.headers.get(API_KEY_HEADER_NAME)
|
||||||
|
if not api_key_header or api_key_header != self.api_key:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "Invalid API Key"},
|
||||||
|
)
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
tokenizer_manager = None
|
tokenizer_manager = None
|
||||||
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
|
|||||||
|
|
||||||
assert proc_router.is_alive() and proc_detoken.is_alive()
|
assert proc_router.is_alive() and proc_detoken.is_alive()
|
||||||
|
|
||||||
|
if server_args.api_key and server_args.api_key != "":
|
||||||
|
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
|
||||||
|
|
||||||
def _launch_server():
|
def _launch_server():
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _wait_and_warmup():
|
def _wait_and_warmup():
|
||||||
|
headers = {}
|
||||||
url = server_args.url()
|
url = server_args.url()
|
||||||
for _ in range(60):
|
if server_args.api_key and server_args.api_key != "":
|
||||||
time.sleep(1)
|
headers[API_KEY_HEADER_NAME] = server_args.api_key
|
||||||
|
|
||||||
|
for _ in range(120):
|
||||||
|
time.sleep(0.5)
|
||||||
try:
|
try:
|
||||||
requests.get(url + "/get_model_info", timeout=5)
|
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
||||||
break
|
break
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
pass
|
pass
|
||||||
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
|
|||||||
"max_new_tokens": 16,
|
"max_new_tokens": 16,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
headers=headers,
|
||||||
timeout=60,
|
timeout=60,
|
||||||
)
|
)
|
||||||
# print(f"Warmup done. model response: {res.json()['text']}")
|
# print(f"Warmup done. model response: {res.json()['text']}")
|
||||||
@@ -558,6 +585,7 @@ class Runtime:
|
|||||||
attention_reduce_in_fp32: bool = False,
|
attention_reduce_in_fp32: bool = False,
|
||||||
random_seed: int = 42,
|
random_seed: int = 42,
|
||||||
log_level: str = "error",
|
log_level: str = "error",
|
||||||
|
api_key: str = "",
|
||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
additional_ports: Optional[Union[List[int], int]] = None,
|
additional_ports: Optional[Union[List[int], int]] = None,
|
||||||
):
|
):
|
||||||
@@ -580,6 +608,7 @@ class Runtime:
|
|||||||
attention_reduce_in_fp32=attention_reduce_in_fp32,
|
attention_reduce_in_fp32=attention_reduce_in_fp32,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.url = self.server_args.url()
|
self.url = self.server_args.url()
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class ServerArgs:
|
|||||||
enable_flashinfer: bool = False
|
enable_flashinfer: bool = False
|
||||||
disable_regex_jump_forward: bool = False
|
disable_regex_jump_forward: bool = False
|
||||||
disable_disk_cache: bool = False
|
disable_disk_cache: bool = False
|
||||||
|
api_key: str = ""
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
@@ -201,6 +202,12 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
|
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-key",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.api_key,
|
||||||
|
help="Set API Key",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
@@ -88,23 +88,22 @@ class HttpResponse:
|
|||||||
return self.resp.status
|
return self.resp.status
|
||||||
|
|
||||||
|
|
||||||
def http_request(url, json=None, stream=False, auth_token=None, verify=None):
|
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
|
||||||
"""A faster version of requests.post with low-level urllib API."""
|
"""A faster version of requests.post with low-level urllib API."""
|
||||||
if stream:
|
headers = {"Content-Type": "application/json; charset=utf-8"}
|
||||||
if auth_token is None:
|
|
||||||
return requests.post(url, json=json, stream=True, verify=verify)
|
# add the Authorization header if an auth token is provided
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authentication": f"Bearer {auth_token}",
|
|
||||||
}
|
|
||||||
return requests.post(
|
|
||||||
url, json=json, stream=True, headers=headers, verify=verify
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
req = urllib.request.Request(url)
|
|
||||||
req.add_header("Content-Type", "application/json; charset=utf-8")
|
|
||||||
if auth_token is not None:
|
if auth_token is not None:
|
||||||
req.add_header("Authentication", f"Bearer {auth_token}")
|
headers["Authorization"] = f"Bearer {auth_token}"
|
||||||
|
|
||||||
|
# add the API Key header if an API key is provided
|
||||||
|
if api_key is not None:
|
||||||
|
headers["X-API-Key"] = api_key
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return requests.post(url, json=json, stream=True, headers=headers)
|
||||||
|
else:
|
||||||
|
req = urllib.request.Request(url, headers=headers)
|
||||||
if json is None:
|
if json is None:
|
||||||
data = None
|
data = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user