Merged three native APIs into one: get_server_info (#2152)
This commit is contained in:
committed by
GitHub
parent
84a1698d67
commit
dbe1729395
@@ -113,7 +113,7 @@ def main(args):
|
|||||||
|
|
||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(
|
||||||
global_config.default_backend.get_server_args()["tokenizer_path"]
|
global_config.default_backend.get_server_info()["tokenizer_path"]
|
||||||
)
|
)
|
||||||
output_jsons = [state["json_output"] for state in states]
|
output_jsons = [state["json_output"] for state in states]
|
||||||
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
|
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
|
||||||
|
|||||||
@@ -9,13 +9,11 @@
|
|||||||
"Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
|
"Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"- `/generate` (text generation model)\n",
|
"- `/generate` (text generation model)\n",
|
||||||
"- `/get_server_args`\n",
|
|
||||||
"- `/get_model_info`\n",
|
"- `/get_model_info`\n",
|
||||||
|
"- `/get_server_info`\n",
|
||||||
"- `/health`\n",
|
"- `/health`\n",
|
||||||
"- `/health_generate`\n",
|
"- `/health_generate`\n",
|
||||||
"- `/flush_cache`\n",
|
"- `/flush_cache`\n",
|
||||||
"- `/get_memory_pool_size`\n",
|
|
||||||
"- `/get_max_total_num_tokens`\n",
|
|
||||||
"- `/update_weights`\n",
|
"- `/update_weights`\n",
|
||||||
"- `/encode`(embedding model)\n",
|
"- `/encode`(embedding model)\n",
|
||||||
"- `/classify`(reward model)\n",
|
"- `/classify`(reward model)\n",
|
||||||
@@ -75,26 +73,6 @@
|
|||||||
"print_highlight(response.json())"
|
"print_highlight(response.json())"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Get Server Args\n",
|
|
||||||
"Get the arguments of a server."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"url = \"http://localhost:30010/get_server_args\"\n",
|
|
||||||
"\n",
|
|
||||||
"response = requests.get(url)\n",
|
|
||||||
"print_highlight(response.json())"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -123,6 +101,32 @@
|
|||||||
"assert response_json.keys() == {\"model_path\", \"is_generation\"}"
|
"assert response_json.keys() == {\"model_path\", \"is_generation\"}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Get Server Info\n",
|
||||||
|
"Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
|
||||||
|
"- Note: `get_server_info` merges the following deprecated endpoints:\n",
|
||||||
|
" - `get_server_args`\n",
|
||||||
|
" - `get_memory_pool_size` \n",
|
||||||
|
" - `get_max_total_num_tokens`"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# get_server_info\n",
|
||||||
|
"\n",
|
||||||
|
"url = \"http://localhost:30010/get_server_info\"\n",
|
||||||
|
"\n",
|
||||||
|
"response = requests.get(url)\n",
|
||||||
|
"print_highlight(response.text)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -179,52 +183,6 @@
|
|||||||
"print_highlight(response.text)"
|
"print_highlight(response.text)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Get Memory Pool Size\n",
|
|
||||||
"\n",
|
|
||||||
"Get the memory pool size in number of tokens.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# get_memory_pool_size\n",
|
|
||||||
"\n",
|
|
||||||
"url = \"http://localhost:30010/get_memory_pool_size\"\n",
|
|
||||||
"\n",
|
|
||||||
"response = requests.get(url)\n",
|
|
||||||
"print_highlight(response.text)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Get Maximum Total Number of Tokens\n",
|
|
||||||
"\n",
|
|
||||||
"Exposes the maximum number of tokens SGLang can handle based on the current configuration."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# get_max_total_num_tokens\n",
|
|
||||||
"\n",
|
|
||||||
"url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
|
|
||||||
"\n",
|
|
||||||
"response = requests.get(url)\n",
|
|
||||||
"print_highlight(response.text)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from sglang.api import (
|
|||||||
gen,
|
gen,
|
||||||
gen_int,
|
gen_int,
|
||||||
gen_string,
|
gen_string,
|
||||||
get_server_args,
|
get_server_info,
|
||||||
image,
|
image,
|
||||||
select,
|
select,
|
||||||
set_default_backend,
|
set_default_backend,
|
||||||
@@ -41,7 +41,7 @@ __all__ = [
|
|||||||
"gen",
|
"gen",
|
||||||
"gen_int",
|
"gen_int",
|
||||||
"gen_string",
|
"gen_string",
|
||||||
"get_server_args",
|
"get_server_info",
|
||||||
"image",
|
"image",
|
||||||
"select",
|
"select",
|
||||||
"set_default_backend",
|
"set_default_backend",
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
|
|||||||
return backend.flush_cache()
|
return backend.flush_cache()
|
||||||
|
|
||||||
|
|
||||||
def get_server_args(backend: Optional[BaseBackend] = None):
|
def get_server_info(backend: Optional[BaseBackend] = None):
|
||||||
backend = backend or global_config.default_backend
|
backend = backend or global_config.default_backend
|
||||||
if backend is None:
|
if backend is None:
|
||||||
return None
|
return None
|
||||||
@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
|
|||||||
# If backend is Runtime
|
# If backend is Runtime
|
||||||
if hasattr(backend, "endpoint"):
|
if hasattr(backend, "endpoint"):
|
||||||
backend = backend.endpoint
|
backend = backend.endpoint
|
||||||
return backend.get_server_args()
|
return backend.get_server_info()
|
||||||
|
|
||||||
|
|
||||||
def gen(
|
def gen(
|
||||||
|
|||||||
@@ -78,5 +78,5 @@ class BaseBackend:
|
|||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_server_args(self):
|
def get_server_info(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
)
|
)
|
||||||
self._assert_success(res)
|
self._assert_success(res)
|
||||||
|
|
||||||
def get_server_args(self):
|
def get_server_info(self):
|
||||||
res = http_request(
|
res = http_request(
|
||||||
self.base_url + "/get_server_args",
|
self.base_url + "/get_server_info",
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -146,10 +146,15 @@ async def get_model_info():
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_server_args")
|
@app.get("/get_server_info")
|
||||||
async def get_server_args():
|
async def get_server_info():
|
||||||
"""Get the server arguments."""
|
try:
|
||||||
return dataclasses.asdict(tokenizer_manager.server_args)
|
return await _get_server_info()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ORJSONResponse(
|
||||||
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/flush_cache")
|
@app.post("/flush_cache")
|
||||||
@@ -185,30 +190,6 @@ async def stop_profile():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_max_total_num_tokens")
|
|
||||||
async def get_max_total_num_tokens():
|
|
||||||
try:
|
|
||||||
return {"max_total_num_tokens": _get_max_total_num_tokens()}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return ORJSONResponse(
|
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
|
|
||||||
async def get_memory_pool_size():
|
|
||||||
"""Get the memory pool size in number of tokens"""
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.get_memory_pool_size()
|
|
||||||
|
|
||||||
return ret
|
|
||||||
except Exception as e:
|
|
||||||
return ORJSONResponse(
|
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights")
|
@app.post("/update_weights")
|
||||||
@time_func_latency
|
@time_func_latency
|
||||||
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
||||||
@@ -542,8 +523,12 @@ def launch_server(
|
|||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
def _get_max_total_num_tokens():
|
async def _get_server_info():
|
||||||
return _max_total_num_tokens
|
return {
|
||||||
|
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||||
|
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
|
||||||
|
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _set_envs_and_config(server_args: ServerArgs):
|
def _set_envs_and_config(server_args: ServerArgs):
|
||||||
@@ -787,14 +772,16 @@ class Runtime:
|
|||||||
response = requests.post(self.url + "/encode", json=json_data)
|
response = requests.post(self.url + "/encode", json=json_data)
|
||||||
return json.dumps(response.json())
|
return json.dumps(response.json())
|
||||||
|
|
||||||
def get_max_total_num_tokens(self):
|
async def get_server_info(self):
|
||||||
response = requests.get(f"{self.url}/get_max_total_num_tokens")
|
async with aiohttp.ClientSession() as session:
|
||||||
if response.status_code == 200:
|
async with session.get(f"{self.url}/get_server_info") as response:
|
||||||
return response.json()["max_total_num_tokens"]
|
if response.status == 200:
|
||||||
else:
|
return await response.json()
|
||||||
raise RuntimeError(
|
else:
|
||||||
f"Failed to get max tokens. {response.json()['error']['message']}"
|
error_data = await response.json()
|
||||||
)
|
raise RuntimeError(
|
||||||
|
f"Failed to get server info. {error_data['error']['message']}"
|
||||||
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
@@ -946,5 +933,5 @@ class Engine:
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(encode_request(obj, None))
|
return loop.run_until_complete(encode_request(obj, None))
|
||||||
|
|
||||||
def get_max_total_num_tokens(self):
|
async def get_server_info(self):
|
||||||
return _get_max_total_num_tokens()
|
return await _get_server_info()
|
||||||
|
|||||||
@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
|
|||||||
forward_request(&data.client, worker_url, "/health_generate".to_string()).await
|
forward_request(&data.client, worker_url, "/health_generate".to_string()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/get_server_args")]
|
#[get("/get_server_info")]
|
||||||
async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
|
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
|
||||||
let worker_url = match data.router.get_first() {
|
let worker_url = match data.router.get_first() {
|
||||||
Some(url) => url,
|
Some(url) => url,
|
||||||
None => return HttpResponse::InternalServerError().finish(),
|
None => return HttpResponse::InternalServerError().finish(),
|
||||||
};
|
};
|
||||||
|
|
||||||
forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
|
forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/v1/models")]
|
#[get("/v1/models")]
|
||||||
@@ -153,7 +153,7 @@ pub async fn startup(
|
|||||||
.service(get_model_info)
|
.service(get_model_info)
|
||||||
.service(health)
|
.service(health)
|
||||||
.service(health_generate)
|
.service(health_generate)
|
||||||
.service(get_server_args)
|
.service(get_server_info)
|
||||||
})
|
})
|
||||||
.bind((host, port))?
|
.bind((host, port))?
|
||||||
.run()
|
.run()
|
||||||
|
|||||||
@@ -63,12 +63,13 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
def test_get_memory_pool_size(self):
|
def test_get_memory_pool_size(self):
|
||||||
response = requests.get(self.base_url + "/get_memory_pool_size")
|
# use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info`
|
||||||
|
response = requests.get(self.base_url + "/get_server_info")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
response = requests.get(self.base_url + "/get_memory_pool_size")
|
response = requests.get(self.base_url + "/get_server_info")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -154,9 +154,18 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
|
||||||
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
|
||||||
|
|
||||||
def test_get_memory_pool_size(self):
|
def test_get_server_info(self):
|
||||||
response = requests.post(self.base_url + "/get_memory_pool_size")
|
response = requests.get(self.base_url + "/get_server_info")
|
||||||
self.assertIsInstance(response.json(), int)
|
response_json = response.json()
|
||||||
|
|
||||||
|
max_total_num_tokens = response_json["max_total_num_tokens"]
|
||||||
|
self.assertIsInstance(max_total_num_tokens, int)
|
||||||
|
|
||||||
|
memory_pool_size = response_json["memory_pool_size"]
|
||||||
|
self.assertIsInstance(memory_pool_size, int)
|
||||||
|
|
||||||
|
attention_backend = response_json["attention_backend"]
|
||||||
|
self.assertIsInstance(attention_backend, str)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user