Merged three native APIs into one: get_server_info (#2152)
This commit is contained in:
committed by
GitHub
parent
84a1698d67
commit
dbe1729395
@@ -11,7 +11,7 @@ from sglang.api import (
|
||||
gen,
|
||||
gen_int,
|
||||
gen_string,
|
||||
get_server_args,
|
||||
get_server_info,
|
||||
image,
|
||||
select,
|
||||
set_default_backend,
|
||||
@@ -41,7 +41,7 @@ __all__ = [
|
||||
"gen",
|
||||
"gen_int",
|
||||
"gen_string",
|
||||
"get_server_args",
|
||||
"get_server_info",
|
||||
"image",
|
||||
"select",
|
||||
"set_default_backend",
|
||||
|
||||
@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
|
||||
return backend.flush_cache()
|
||||
|
||||
|
||||
def get_server_args(backend: Optional[BaseBackend] = None):
|
||||
def get_server_info(backend: Optional[BaseBackend] = None):
|
||||
backend = backend or global_config.default_backend
|
||||
if backend is None:
|
||||
return None
|
||||
@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
|
||||
# If backend is Runtime
|
||||
if hasattr(backend, "endpoint"):
|
||||
backend = backend.endpoint
|
||||
return backend.get_server_args()
|
||||
return backend.get_server_info()
|
||||
|
||||
|
||||
def gen(
|
||||
|
||||
@@ -78,5 +78,5 @@ class BaseBackend:
|
||||
def flush_cache(self):
|
||||
pass
|
||||
|
||||
def get_server_args(self):
|
||||
def get_server_info(self):
|
||||
pass
|
||||
|
||||
@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
|
||||
)
|
||||
self._assert_success(res)
|
||||
|
||||
def get_server_args(self):
|
||||
def get_server_info(self):
|
||||
res = http_request(
|
||||
self.base_url + "/get_server_args",
|
||||
self.base_url + "/get_server_info",
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
|
||||
@@ -146,10 +146,15 @@ async def get_model_info():
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/get_server_args")
|
||||
async def get_server_args():
|
||||
"""Get the server arguments."""
|
||||
return dataclasses.asdict(tokenizer_manager.server_args)
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
try:
|
||||
return await _get_server_info()
|
||||
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.post("/flush_cache")
|
||||
@@ -185,30 +190,6 @@ async def stop_profile():
|
||||
)
|
||||
|
||||
|
||||
@app.get("/get_max_total_num_tokens")
|
||||
async def get_max_total_num_tokens():
|
||||
try:
|
||||
return {"max_total_num_tokens": _get_max_total_num_tokens()}
|
||||
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
|
||||
async def get_memory_pool_size():
|
||||
"""Get the memory pool size in number of tokens"""
|
||||
try:
|
||||
ret = await tokenizer_manager.get_memory_pool_size()
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.post("/update_weights")
|
||||
@time_func_latency
|
||||
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
||||
@@ -542,8 +523,12 @@ def launch_server(
|
||||
t.join()
|
||||
|
||||
|
||||
def _get_max_total_num_tokens():
|
||||
return _max_total_num_tokens
|
||||
async def _get_server_info():
|
||||
return {
|
||||
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
|
||||
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
|
||||
}
|
||||
|
||||
|
||||
def _set_envs_and_config(server_args: ServerArgs):
|
||||
@@ -787,14 +772,16 @@ class Runtime:
|
||||
response = requests.post(self.url + "/encode", json=json_data)
|
||||
return json.dumps(response.json())
|
||||
|
||||
def get_max_total_num_tokens(self):
|
||||
response = requests.get(f"{self.url}/get_max_total_num_tokens")
|
||||
if response.status_code == 200:
|
||||
return response.json()["max_total_num_tokens"]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Failed to get max tokens. {response.json()['error']['message']}"
|
||||
)
|
||||
async def get_server_info(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.url}/get_server_info") as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
error_data = await response.json()
|
||||
raise RuntimeError(
|
||||
f"Failed to get server info. {error_data['error']['message']}"
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
@@ -946,5 +933,5 @@ class Engine:
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(encode_request(obj, None))
|
||||
|
||||
def get_max_total_num_tokens(self):
|
||||
return _get_max_total_num_tokens()
|
||||
async def get_server_info(self):
|
||||
return await _get_server_info()
|
||||
|
||||
Reference in New Issue
Block a user