Fix RuntimeEndpoint (#279)
This commit is contained in:
@@ -43,18 +43,21 @@ def Runtime(*args, **kwargs):
|
|||||||
def set_default_backend(backend: BaseBackend):
|
def set_default_backend(backend: BaseBackend):
|
||||||
global_config.default_backend = backend
|
global_config.default_backend = backend
|
||||||
|
|
||||||
|
|
||||||
def flush_cache(backend: BaseBackend = None):
|
def flush_cache(backend: BaseBackend = None):
|
||||||
backend = backend or global_config.default_backend
|
backend = backend or global_config.default_backend
|
||||||
if backend is None:
|
if backend is None:
|
||||||
return False
|
return False
|
||||||
return backend.flush_cache()
|
return backend.flush_cache()
|
||||||
|
|
||||||
|
|
||||||
def get_server_args(backend: BaseBackend = None):
|
def get_server_args(backend: BaseBackend = None):
|
||||||
backend = backend or global_config.default_backend
|
backend = backend or global_config.default_backend
|
||||||
if backend is None:
|
if backend is None:
|
||||||
return None
|
return None
|
||||||
return backend.get_server_args()
|
return backend.get_server_args()
|
||||||
|
|
||||||
|
|
||||||
def gen(
|
def gen(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
|
|||||||
@@ -12,7 +12,13 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
|
|||||||
|
|
||||||
|
|
||||||
class RuntimeEndpoint(BaseBackend):
|
class RuntimeEndpoint(BaseBackend):
|
||||||
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
auth_token: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
verify: Optional[str] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.support_concate_and_append = True
|
self.support_concate_and_append = True
|
||||||
|
|
||||||
@@ -61,7 +67,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
api_key=self.api_key
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@@ -71,7 +77,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
api_key=self.api_key
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
@@ -159,7 +165,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
json=data,
|
json=data,
|
||||||
stream=True,
|
stream=True,
|
||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
api_key=self.api_key
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
pos = 0
|
pos = 0
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ import requests
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import JSONResponse
|
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
@@ -56,11 +54,14 @@ from sglang.srt.managers.router.manager import start_router_process
|
|||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import handle_port_init
|
from sglang.srt.utils import handle_port_init
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
API_KEY_HEADER_NAME = "X-API-Key"
|
API_KEY_HEADER_NAME = "X-API-Key"
|
||||||
|
|
||||||
|
|
||||||
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||||
def __init__(self, app, api_key: str):
|
def __init__(self, app, api_key: str):
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
@@ -77,6 +78,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
|||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
tokenizer_manager = None
|
tokenizer_manager = None
|
||||||
chat_template_name = None
|
chat_template_name = None
|
||||||
|
|||||||
@@ -88,7 +88,9 @@ class HttpResponse:
|
|||||||
return self.resp.status
|
return self.resp.status
|
||||||
|
|
||||||
|
|
||||||
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
|
def http_request(
|
||||||
|
url, json=None, stream=False, auth_token=None, api_key=None, verify=None
|
||||||
|
):
|
||||||
"""A faster version of requests.post with low-level urllib API."""
|
"""A faster version of requests.post with low-level urllib API."""
|
||||||
headers = {"Content-Type": "application/json; charset=utf-8"}
|
headers = {"Content-Type": "application/json; charset=utf-8"}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user