Merged three native APIs into one: get_server_info (#2152)

This commit is contained in:
Henry Hyeonmok Ko
2024-11-24 01:37:58 -08:00
committed by GitHub
parent 84a1698d67
commit dbe1729395
10 changed files with 81 additions and 126 deletions

View File

@@ -113,7 +113,7 @@ def main(args):
# Compute accuracy # Compute accuracy
tokenizer = get_tokenizer( tokenizer = get_tokenizer(
global_config.default_backend.get_server_args()["tokenizer_path"] global_config.default_backend.get_server_info()["tokenizer_path"]
) )
output_jsons = [state["json_output"] for state in states] output_jsons = [state["json_output"] for state in states]
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons) num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)

View File

@@ -9,13 +9,11 @@
"Apart from the OpenAI-compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce the following APIs:\n", "Apart from the OpenAI-compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce the following APIs:\n",
"\n", "\n",
"- `/generate` (text generation model)\n", "- `/generate` (text generation model)\n",
"- `/get_server_args`\n",
"- `/get_model_info`\n", "- `/get_model_info`\n",
"- `/get_server_info`\n",
"- `/health`\n", "- `/health`\n",
"- `/health_generate`\n", "- `/health_generate`\n",
"- `/flush_cache`\n", "- `/flush_cache`\n",
"- `/get_memory_pool_size`\n",
"- `/get_max_total_num_tokens`\n",
"- `/update_weights`\n", "- `/update_weights`\n",
"- `/encode`(embedding model)\n", "- `/encode`(embedding model)\n",
"- `/classify`(reward model)\n", "- `/classify`(reward model)\n",
@@ -75,26 +73,6 @@
"print_highlight(response.json())" "print_highlight(response.json())"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Server Args\n",
"Get the arguments of a server."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = \"http://localhost:30010/get_server_args\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.json())"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@@ -123,6 +101,32 @@
"assert response_json.keys() == {\"model_path\", \"is_generation\"}" "assert response_json.keys() == {\"model_path\", \"is_generation\"}"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Server Info\n",
"Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
"- Note: `get_server_info` merges the following deprecated endpoints:\n",
" - `get_server_args`\n",
" - `get_memory_pool_size` \n",
" - `get_max_total_num_tokens`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get_server_info\n",
"\n",
"url = \"http://localhost:30010/get_server_info\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@@ -179,52 +183,6 @@
"print_highlight(response.text)" "print_highlight(response.text)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Memory Pool Size\n",
"\n",
"Get the memory pool size in number of tokens.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get_memory_pool_size\n",
"\n",
"url = \"http://localhost:30010/get_memory_pool_size\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Maximum Total Number of Tokens\n",
"\n",
"Exposes the maximum number of tokens SGLang can handle based on the current configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get_max_total_num_tokens\n",
"\n",
"url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},

View File

@@ -11,7 +11,7 @@ from sglang.api import (
gen, gen,
gen_int, gen_int,
gen_string, gen_string,
get_server_args, get_server_info,
image, image,
select, select,
set_default_backend, set_default_backend,
@@ -41,7 +41,7 @@ __all__ = [
"gen", "gen",
"gen_int", "gen_int",
"gen_string", "gen_string",
"get_server_args", "get_server_info",
"image", "image",
"select", "select",
"set_default_backend", "set_default_backend",

View File

@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
return backend.flush_cache() return backend.flush_cache()
def get_server_args(backend: Optional[BaseBackend] = None): def get_server_info(backend: Optional[BaseBackend] = None):
backend = backend or global_config.default_backend backend = backend or global_config.default_backend
if backend is None: if backend is None:
return None return None
@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
# If backend is Runtime # If backend is Runtime
if hasattr(backend, "endpoint"): if hasattr(backend, "endpoint"):
backend = backend.endpoint backend = backend.endpoint
return backend.get_server_args() return backend.get_server_info()
def gen( def gen(

View File

@@ -78,5 +78,5 @@ class BaseBackend:
def flush_cache(self): def flush_cache(self):
pass pass
def get_server_args(self): def get_server_info(self):
pass pass

View File

@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
) )
self._assert_success(res) self._assert_success(res)
def get_server_args(self): def get_server_info(self):
res = http_request( res = http_request(
self.base_url + "/get_server_args", self.base_url + "/get_server_info",
api_key=self.api_key, api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )

View File

@@ -146,10 +146,15 @@ async def get_model_info():
return result return result
@app.get("/get_server_args") @app.get("/get_server_info")
async def get_server_args(): async def get_server_info():
"""Get the server arguments.""" try:
return dataclasses.asdict(tokenizer_manager.server_args) return await _get_server_info()
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.post("/flush_cache") @app.post("/flush_cache")
@@ -185,30 +190,6 @@ async def stop_profile():
) )
@app.get("/get_max_total_num_tokens")
async def get_max_total_num_tokens():
try:
return {"max_total_num_tokens": _get_max_total_num_tokens()}
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
async def get_memory_pool_size():
"""Get the memory pool size in number of tokens"""
try:
ret = await tokenizer_manager.get_memory_pool_size()
return ret
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.post("/update_weights") @app.post("/update_weights")
@time_func_latency @time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request): async def update_weights(obj: UpdateWeightReqInput, request: Request):
@@ -542,8 +523,12 @@ def launch_server(
t.join() t.join()
def _get_max_total_num_tokens(): async def _get_server_info():
return _max_total_num_tokens return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
}
def _set_envs_and_config(server_args: ServerArgs): def _set_envs_and_config(server_args: ServerArgs):
@@ -787,14 +772,16 @@ class Runtime:
response = requests.post(self.url + "/encode", json=json_data) response = requests.post(self.url + "/encode", json=json_data)
return json.dumps(response.json()) return json.dumps(response.json())
def get_max_total_num_tokens(self): async def get_server_info(self):
response = requests.get(f"{self.url}/get_max_total_num_tokens") async with aiohttp.ClientSession() as session:
if response.status_code == 200: async with session.get(f"{self.url}/get_server_info") as response:
return response.json()["max_total_num_tokens"] if response.status == 200:
else: return await response.json()
raise RuntimeError( else:
f"Failed to get max tokens. {response.json()['error']['message']}" error_data = await response.json()
) raise RuntimeError(
f"Failed to get server info. {error_data['error']['message']}"
)
def __del__(self): def __del__(self):
self.shutdown() self.shutdown()
@@ -946,5 +933,5 @@ class Engine:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None)) return loop.run_until_complete(encode_request(obj, None))
def get_max_total_num_tokens(self): async def get_server_info(self):
return _get_max_total_num_tokens() return await _get_server_info()

View File

@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
forward_request(&data.client, worker_url, "/health_generate".to_string()).await forward_request(&data.client, worker_url, "/health_generate".to_string()).await
} }
#[get("/get_server_args")] #[get("/get_server_info")]
async fn get_server_args(data: web::Data<AppState>) -> impl Responder { async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() { let worker_url = match data.router.get_first() {
Some(url) => url, Some(url) => url,
None => return HttpResponse::InternalServerError().finish(), None => return HttpResponse::InternalServerError().finish(),
}; };
forward_request(&data.client, worker_url, "/get_server_args".to_string()).await forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
} }
#[get("/v1/models")] #[get("/v1/models")]
@@ -153,7 +153,7 @@ pub async fn startup(
.service(get_model_info) .service(get_model_info)
.service(health) .service(health)
.service(health_generate) .service(health_generate)
.service(get_server_args) .service(get_server_info)
}) })
.bind((host, port))? .bind((host, port))?
.run() .run()

View File

@@ -63,12 +63,13 @@ class TestDataParallelism(unittest.TestCase):
assert response.status_code == 200 assert response.status_code == 200
def test_get_memory_pool_size(self): def test_get_memory_pool_size(self):
response = requests.get(self.base_url + "/get_memory_pool_size") # use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info`
response = requests.get(self.base_url + "/get_server_info")
assert response.status_code == 200 assert response.status_code == 200
time.sleep(5) time.sleep(5)
response = requests.get(self.base_url + "/get_memory_pool_size") response = requests.get(self.base_url + "/get_server_info")
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -154,9 +154,18 @@ class TestSRTEndpoint(unittest.TestCase):
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens) self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens) self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
def test_get_memory_pool_size(self): def test_get_server_info(self):
response = requests.post(self.base_url + "/get_memory_pool_size") response = requests.get(self.base_url + "/get_server_info")
self.assertIsInstance(response.json(), int) response_json = response.json()
max_total_num_tokens = response_json["max_total_num_tokens"]
self.assertIsInstance(max_total_num_tokens, int)
memory_pool_size = response_json["memory_pool_size"]
self.assertIsInstance(memory_pool_size, int)
attention_backend = response_json["attention_backend"]
self.assertIsInstance(attention_backend, str)
if __name__ == "__main__": if __name__ == "__main__":