Fix RuntimeEndpoint (#279)
This commit is contained in:
@@ -43,18 +43,21 @@ def Runtime(*args, **kwargs):
|
||||
def set_default_backend(backend: BaseBackend):
    """Record ``backend`` as the process-wide default.

    The value is stored on ``global_config.default_backend``.
    """
    global_config.default_backend = backend
|
||||
|
||||
|
||||
def flush_cache(backend: BaseBackend = None):
    """Flush the cache of ``backend`` or of the default backend.

    When ``backend`` is falsy, ``global_config.default_backend`` is used
    instead. Returns ``False`` if no backend is available; otherwise returns
    whatever the backend's own ``flush_cache()`` returns.
    """
    target = backend if backend else global_config.default_backend
    if target is not None:
        return target.flush_cache()
    # Neither an explicit nor a default backend is configured.
    return False
|
||||
|
||||
|
||||
def get_server_args(backend: BaseBackend = None):
    """Fetch server arguments from ``backend`` or from the default backend.

    When ``backend`` is falsy, ``global_config.default_backend`` is used
    instead. Returns ``None`` if no backend is available; otherwise returns
    whatever the backend's own ``get_server_args()`` returns.
    """
    target = backend if backend else global_config.default_backend
    if target is not None:
        return target.get_server_args()
    # Neither an explicit nor a default backend is configured.
    return None
|
||||
|
||||
|
||||
def gen(
|
||||
name: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
|
||||
@@ -12,7 +12,13 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
|
||||
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
auth_token: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
verify: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.support_concate_and_append = True
|
||||
|
||||
@@ -61,7 +67,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
@@ -71,7 +77,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
@@ -159,7 +165,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
json=data,
|
||||
stream=True,
|
||||
auth_token=self.auth_token,
|
||||
api_key=self.api_key
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
pos = 0
|
||||
|
||||
@@ -20,8 +20,6 @@ import requests
|
||||
import uvicorn
|
||||
import uvloop
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
@@ -56,11 +54,14 @@ from sglang.srt.managers.router.manager import start_router_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import handle_port_init
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
API_KEY_HEADER_NAME = "X-API-Key"
|
||||
|
||||
|
||||
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(self, app, api_key: str):
|
||||
super().__init__(app)
|
||||
@@ -77,6 +78,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
tokenizer_manager = None
|
||||
chat_template_name = None
|
||||
|
||||
@@ -88,7 +88,9 @@ class HttpResponse:
|
||||
return self.resp.status
|
||||
|
||||
|
||||
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
|
||||
def http_request(
|
||||
url, json=None, stream=False, auth_token=None, api_key=None, verify=None
|
||||
):
|
||||
"""A faster version of requests.post with low-level urllib API."""
|
||||
headers = {"Content-Type": "application/json; charset=utf-8"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user