Expose max total num tokens from Runtime & Engine API (#2092)
This commit is contained in:
committed by
GitHub
parent
72f87b723b
commit
c35cd1f8c7
@@ -15,6 +15,7 @@
|
|||||||
"- `/health_generate`\n",
|
"- `/health_generate`\n",
|
||||||
"- `/flush_cache`\n",
|
"- `/flush_cache`\n",
|
||||||
"- `/get_memory_pool_size`\n",
|
"- `/get_memory_pool_size`\n",
|
||||||
|
"- `/get_max_total_num_tokens`\n",
|
||||||
"- `/update_weights`\n",
|
"- `/update_weights`\n",
|
||||||
"- `/encode`(embedding model)\n",
|
"- `/encode`(embedding model)\n",
|
||||||
"- `/classify`(reward model)\n",
|
"- `/classify`(reward model)\n",
|
||||||
@@ -201,6 +202,29 @@
|
|||||||
"print_highlight(response.text)"
|
"print_highlight(response.text)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Get Maximum Total Number of Tokens\n",
|
||||||
|
"\n",
|
||||||
|
"Exposes the maximum number of tokens SGLang can handle based on the current configuration."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# get_max_total_num_tokens\n",
|
||||||
|
"\n",
|
||||||
|
"url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
|
||||||
|
"\n",
|
||||||
|
"response = requests.get(url)\n",
|
||||||
|
"print_highlight(response.text)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
|||||||
@@ -167,9 +167,12 @@ class DataParallelController:
|
|||||||
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for model to finish loading
|
# Wait for model to finish loading and get max token nums
|
||||||
|
scheduler_info = []
|
||||||
for i in range(len(scheduler_pipe_readers)):
|
for i in range(len(scheduler_pipe_readers)):
|
||||||
scheduler_pipe_readers[i].recv()
|
scheduler_info.append(scheduler_pipe_readers[i].recv())
|
||||||
|
|
||||||
|
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||||
|
|
||||||
return send_to
|
return send_to
|
||||||
|
|
||||||
@@ -191,7 +194,10 @@ class DataParallelController:
|
|||||||
send_to = get_zmq_socket(
|
send_to = get_zmq_socket(
|
||||||
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
||||||
)
|
)
|
||||||
reader.recv()
|
|
||||||
|
scheduler_info = reader.recv()
|
||||||
|
self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
|
||||||
|
|
||||||
return send_to
|
return send_to
|
||||||
|
|
||||||
def round_robin_scheduler(self, req):
|
def round_robin_scheduler(self, req):
|
||||||
@@ -233,7 +239,9 @@ def run_data_parallel_controller_process(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
controller = DataParallelController(server_args, port_args)
|
controller = DataParallelController(server_args, port_args)
|
||||||
pipe_writer.send("ready")
|
pipe_writer.send(
|
||||||
|
{"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
|
||||||
|
)
|
||||||
controller.event_loop()
|
controller.event_loop()
|
||||||
except Exception:
|
except Exception:
|
||||||
msg = get_exception_traceback()
|
msg = get_exception_traceback()
|
||||||
|
|||||||
@@ -1400,7 +1400,9 @@ def run_scheduler_process(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||||
pipe_writer.send("ready")
|
pipe_writer.send(
|
||||||
|
{"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
|
||||||
|
)
|
||||||
if scheduler.enable_overlap:
|
if scheduler.enable_overlap:
|
||||||
scheduler.event_loop_overlap()
|
scheduler.event_loop_overlap()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tokenizer_manager: TokenizerManager = None
|
tokenizer_manager: TokenizerManager = None
|
||||||
|
_max_total_num_tokens = None
|
||||||
|
|
||||||
##### Native API endpoints #####
|
##### Native API endpoints #####
|
||||||
|
|
||||||
@@ -184,6 +185,17 @@ async def stop_profile():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_max_total_num_tokens")
async def get_max_total_num_tokens():
    """Return the maximum total number of tokens as a JSON object.

    On success the response body is ``{"max_total_num_tokens": <value>}``,
    where the value comes from ``_get_max_total_num_tokens()``. If that
    lookup raises, the exception text is returned as
    ``{"error": {"message": ...}}`` with HTTP status 400 (BAD_REQUEST).
    """
    try:
        payload = {"max_total_num_tokens": _get_max_total_num_tokens()}
    except Exception as exc:
        error_body = {"error": {"message": str(exc)}}
        return ORJSONResponse(error_body, status_code=HTTPStatus.BAD_REQUEST)
    return payload
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
|
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
|
||||||
async def get_memory_pool_size():
|
async def get_memory_pool_size():
|
||||||
"""Get the memory pool size in number of tokens"""
|
"""Get the memory pool size in number of tokens"""
|
||||||
@@ -390,6 +402,7 @@ def launch_engine(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
|
global _max_total_num_tokens
|
||||||
|
|
||||||
# Configure global environment
|
# Configure global environment
|
||||||
configure_logger(server_args)
|
configure_logger(server_args)
|
||||||
@@ -455,9 +468,20 @@ def launch_engine(
|
|||||||
if server_args.chat_template:
|
if server_args.chat_template:
|
||||||
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
||||||
|
|
||||||
# Wait for model to finish loading
|
# Wait for model to finish loading & get max token nums
|
||||||
|
scheduler_info = []
|
||||||
for i in range(len(scheduler_pipe_readers)):
|
for i in range(len(scheduler_pipe_readers)):
|
||||||
scheduler_pipe_readers[i].recv()
|
data = scheduler_pipe_readers[i].recv()
|
||||||
|
|
||||||
|
if data["status"] != "ready":
    # NOTE: launch_engine is a module-level function, so there is no
    # `self` to shut down here. The previous `self.shutdown()` call raised
    # NameError and hid the initialization error reported below.
    raise RuntimeError(
        "Initialization failed. Please see the error messages above."
    )
|
||||||
|
scheduler_info.append(data)
|
||||||
|
|
||||||
|
# Assume all schedulers have same max_total_num_tokens
|
||||||
|
_max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||||
|
|
||||||
|
|
||||||
def launch_server(
|
def launch_server(
|
||||||
@@ -518,6 +542,10 @@ def launch_server(
|
|||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_max_total_num_tokens():
    """Return the module-level ``_max_total_num_tokens`` value.

    The value is whatever the module global currently holds; it is
    initialised to ``None`` and assigned inside ``launch_engine``.
    """
    global _max_total_num_tokens  # read-only access; declared for clarity
    return _max_total_num_tokens
|
||||||
|
|
||||||
|
|
||||||
def _set_envs_and_config(server_args: ServerArgs):
|
def _set_envs_and_config(server_args: ServerArgs):
|
||||||
# Set global environments
|
# Set global environments
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
@@ -759,6 +787,15 @@ class Runtime:
|
|||||||
response = requests.post(self.url + "/encode", json=json_data)
|
response = requests.post(self.url + "/encode", json=json_data)
|
||||||
return json.dumps(response.json())
|
return json.dumps(response.json())
|
||||||
|
|
||||||
|
def get_max_total_num_tokens(self):
    """Query the server for its maximum total number of tokens.

    Sends a GET request to ``{self.url}/get_max_total_num_tokens``.

    Returns:
        The ``max_total_num_tokens`` value from the JSON response.

    Raises:
        RuntimeError: If the server answers with a non-200 status code.
    """
    response = requests.get(f"{self.url}/get_max_total_num_tokens")
    if response.status_code == 200:
        return response.json()["max_total_num_tokens"]

    # The endpoint reports its own failures as {"error": {"message": ...}},
    # but other non-200 replies may carry a non-JSON body or a different
    # schema. Fall back to the raw text so the real failure is not masked
    # by a KeyError / JSON decoding error while building the message.
    try:
        detail = response.json()["error"]["message"]
    except (ValueError, KeyError, TypeError):
        detail = response.text
    raise RuntimeError(f"Failed to get max tokens. {detail}")
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
|
|
||||||
@@ -908,3 +945,6 @@ class Engine:
|
|||||||
# get the current event loop
|
# get the current event loop
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(encode_request(obj, None))
|
return loop.run_until_complete(encode_request(obj, None))
|
||||||
|
|
||||||
|
def get_max_total_num_tokens(self):
    """Return the maximum total number of tokens known to this process.

    Delegates to the module-level ``_get_max_total_num_tokens()`` helper;
    no HTTP request is made.
    """
    max_tokens = _get_max_total_num_tokens()
    return max_tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user