Add Support for API Key Authentication (#230)

This commit is contained in:
Alessio Dalla Piazza
2024-03-11 13:16:10 +01:00
committed by GitHub
parent 1b35547927
commit d5ae2ebaa2
4 changed files with 63 additions and 18 deletions

View File

@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend): class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url, auth_token=None, verify=None): def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
super().__init__() super().__init__()
self.support_concate_and_append = True self.support_concate_and_append = True
self.base_url = base_url self.base_url = base_url
self.auth_token = auth_token self.auth_token = auth_token
self.api_key = api_key
self.verify = verify self.verify = verify
res = http_request( res = http_request(
self.base_url + "/get_model_info", self.base_url + "/get_model_info",
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}}, json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
obj = res.json() obj = res.json()
@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
json=data, json=data,
stream=True, stream=True,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
pos = 0 pos = 0
@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/concate_and_append_request", self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid}, json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200

View File

@@ -20,6 +20,8 @@ import requests
import uvicorn import uvicorn
import uvloop import uvloop
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Reject any request that does not carry the configured API key.

    The key is read from the request header named by ``API_KEY_HEADER_NAME``.
    A request whose header is missing, empty, or different from the configured
    key receives a 403 JSON response and is never passed to the wrapped app.
    """

    def __init__(self, app, api_key: str):
        """Wrap *app* and store the key that incoming requests must present."""
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        """Forward the request only when its API key header matches.

        Returns the downstream response on success; otherwise a
        ``JSONResponse`` with status 403 and ``{"detail": "Invalid API Key"}``.
        """
        # Imported here so the block does not depend on a file-level import.
        import hmac

        # Extract the API key from the request headers.
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)

        # Compare in constant time: the header is attacker-controlled, and a
        # plain ``!=`` can leak how many leading characters matched through
        # response timing. Both sides are encoded to bytes because
        # ``hmac.compare_digest`` rejects non-ASCII ``str`` arguments.
        if not api_key_header or not hmac.compare_digest(
            api_key_header.encode("utf-8"), self.api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )

        return await call_next(request)
app = FastAPI() app = FastAPI()
tokenizer_manager = None tokenizer_manager = None
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
assert proc_router.is_alive() and proc_detoken.is_alive() assert proc_router.is_alive() and proc_detoken.is_alive()
if server_args.api_key and server_args.api_key != "":
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
def _launch_server(): def _launch_server():
uvicorn.run( uvicorn.run(
app, app,
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
) )
def _wait_and_warmup(): def _wait_and_warmup():
headers = {}
url = server_args.url() url = server_args.url()
for _ in range(60): if server_args.api_key and server_args.api_key != "":
time.sleep(1) headers[API_KEY_HEADER_NAME] = server_args.api_key
for _ in range(120):
time.sleep(0.5)
try: try:
requests.get(url + "/get_model_info", timeout=5) requests.get(url + "/get_model_info", timeout=5, headers=headers)
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
pass pass
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
"max_new_tokens": 16, "max_new_tokens": 16,
}, },
}, },
headers=headers,
timeout=60, timeout=60,
) )
# print(f"Warmup done. model response: {res.json()['text']}") # print(f"Warmup done. model response: {res.json()['text']}")
@@ -558,6 +585,7 @@ class Runtime:
attention_reduce_in_fp32: bool = False, attention_reduce_in_fp32: bool = False,
random_seed: int = 42, random_seed: int = 42,
log_level: str = "error", log_level: str = "error",
api_key: str = "",
port: Optional[int] = None, port: Optional[int] = None,
additional_ports: Optional[Union[List[int], int]] = None, additional_ports: Optional[Union[List[int], int]] = None,
): ):
@@ -580,6 +608,7 @@ class Runtime:
attention_reduce_in_fp32=attention_reduce_in_fp32, attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed, random_seed=random_seed,
log_level=log_level, log_level=log_level,
api_key=api_key,
) )
self.url = self.server_args.url() self.url = self.server_args.url()

View File

@@ -32,6 +32,7 @@ class ServerArgs:
enable_flashinfer: bool = False enable_flashinfer: bool = False
disable_regex_jump_forward: bool = False disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False disable_disk_cache: bool = False
api_key: str = ""
def __post_init__(self): def __post_init__(self):
if self.tokenizer_path is None: if self.tokenizer_path is None:
@@ -201,6 +202,12 @@ class ServerArgs:
action="store_true", action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
) )
parser.add_argument(
"--api-key",
type=str,
default=ServerArgs.api_key,
help="Set API Key",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):

View File

@@ -88,23 +88,22 @@ class HttpResponse:
return self.resp.status return self.resp.status
def http_request(url, json=None, stream=False, auth_token=None, verify=None): def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
"""A faster version of requests.post with low-level urllib API.""" """A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}
# add the Authorization header if an auth token is provided
if auth_token is not None:
headers["Authorization"] = f"Bearer {auth_token}"
# add the API Key header if an API key is provided
if api_key is not None:
headers["X-API-Key"] = api_key
if stream: if stream:
if auth_token is None: return requests.post(url, json=json, stream=True, headers=headers)
return requests.post(url, json=json, stream=True, verify=verify)
headers = {
"Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}",
}
return requests.post(
url, json=json, stream=True, headers=headers, verify=verify
)
else: else:
req = urllib.request.Request(url) req = urllib.request.Request(url, headers=headers)
req.add_header("Content-Type", "application/json; charset=utf-8")
if auth_token is not None:
req.add_header("Authentication", f"Bearer {auth_token}")
if json is None: if json is None:
data = None data = None
else: else: