From 13662fd5336fc8428e130567fdb1695d664eea24 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 11 Mar 2024 05:24:24 -0700 Subject: [PATCH] Fix RuntimeEndpoint (#279) --- python/sglang/api.py | 3 +++ python/sglang/backend/runtime_endpoint.py | 14 ++++++++++---- python/sglang/srt/server.py | 6 ++++-- python/sglang/utils.py | 4 +++- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/python/sglang/api.py b/python/sglang/api.py index f1337a67e..9470b1425 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -43,18 +43,21 @@ def Runtime(*args, **kwargs): def set_default_backend(backend: BaseBackend): global_config.default_backend = backend + def flush_cache(backend: BaseBackend = None): backend = backend or global_config.default_backend if backend is None: return False return backend.flush_cache() + def get_server_args(backend: BaseBackend = None): backend = backend or global_config.default_backend if backend is None: return None return backend.get_server_args() + def gen( name: Optional[str] = None, max_tokens: Optional[int] = None, diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index fc5774ef5..3d2ecaa76 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -12,7 +12,13 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request class RuntimeEndpoint(BaseBackend): - def __init__(self, base_url, auth_token=None, api_key=None, verify=None): + def __init__( + self, + base_url: str, + auth_token: Optional[str] = None, + api_key: Optional[str] = None, + verify: Optional[str] = None, + ): super().__init__() self.support_concate_and_append = True @@ -61,7 +67,7 @@ class RuntimeEndpoint(BaseBackend): self.base_url + "/generate", json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, auth_token=self.auth_token, - api_key=self.api_key + api_key=self.api_key, verify=self.verify, ) assert res.status_code == 200 @@ -71,7 +77,7 @@ class RuntimeEndpoint(BaseBackend): self.base_url + "/generate", json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}}, auth_token=self.auth_token, - api_key=self.api_key + api_key=self.api_key, verify=self.verify, ) assert res.status_code == 200 @@ -159,7 +165,7 @@ class RuntimeEndpoint(BaseBackend): json=data, stream=True, auth_token=self.auth_token, - api_key=self.api_key + api_key=self.api_key, verify=self.verify, ) pos = 0 diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 3d853bd92..7971c6e3a 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -20,8 +20,6 @@ import requests import uvicorn import uvloop from fastapi import FastAPI, HTTPException, Request -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import JSONResponse from fastapi.responses import Response, StreamingResponse from pydantic import BaseModel from sglang.backend.runtime_endpoint import RuntimeEndpoint @@ -56,11 +54,14 @@ from sglang.srt.managers.router.manager import start_router_process from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import handle_port_init +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) API_KEY_HEADER_NAME = "X-API-Key" + class APIKeyValidatorMiddleware(BaseHTTPMiddleware): def __init__(self, app, api_key: str): super().__init__(app) @@ -77,6 +78,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): response = await call_next(request) return response + app = FastAPI() tokenizer_manager = None chat_template_name = None diff --git a/python/sglang/utils.py b/python/sglang/utils.py index e7638e6a4..aa993d8f8 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -88,7 +88,9 @@ class HttpResponse: return self.resp.status -def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None): +def http_request( + url, json=None, stream=False, auth_token=None, api_key=None, verify=None +): """A faster version of requests.post with low-level urllib API.""" headers = {"Content-Type": "application/json; charset=utf-8"}