Fix RuntimeEndpoint (#279)

This commit is contained in:
Lianmin Zheng
2024-03-11 05:24:24 -07:00
committed by GitHub
parent d5ae2ebaa2
commit 13662fd533
4 changed files with 20 additions and 7 deletions

View File

@@ -43,18 +43,21 @@ def Runtime(*args, **kwargs):
def set_default_backend(backend: BaseBackend):
global_config.default_backend = backend
def flush_cache(backend: BaseBackend = None):
backend = backend or global_config.default_backend
if backend is None:
return False
return backend.flush_cache()
def get_server_args(backend: BaseBackend = None):
backend = backend or global_config.default_backend
if backend is None:
return None
return backend.get_server_args()
def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,

View File

@@ -12,7 +12,13 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
api_key: Optional[str] = None,
verify: Optional[str] = None,
):
super().__init__()
self.support_concate_and_append = True
@@ -61,7 +67,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
@@ -71,7 +77,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
@@ -159,7 +165,7 @@ class RuntimeEndpoint(BaseBackend):
json=data,
stream=True,
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
pos = 0

View File

@@ -20,8 +20,6 @@ import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
@@ -56,11 +54,14 @@ from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import handle_port_init
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
def __init__(self, app, api_key: str):
super().__init__(app)
@@ -77,6 +78,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
return response
app = FastAPI()
tokenizer_manager = None
chat_template_name = None

View File

@@ -88,7 +88,9 @@ class HttpResponse:
return self.resp.status
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
def http_request(
url, json=None, stream=False, auth_token=None, api_key=None, verify=None
):
"""A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}