diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb
index daaf52660..4a27d1f7f 100644
--- a/docs/backend/native_api.ipynb
+++ b/docs/backend/native_api.ipynb
@@ -15,6 +15,7 @@
     "- `/health_generate`\n",
     "- `/flush_cache`\n",
     "- `/get_memory_pool_size`\n",
+    "- `/get_max_total_num_tokens`\n",
     "- `/update_weights`\n",
     "- `/encode`(embedding model)\n",
     "- `/classify`(reward model)\n",
@@ -201,6 +202,29 @@
     "print_highlight(response.text)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Get Maximum Total Number of Tokens\n",
+    "\n",
+    "This endpoint exposes the maximum total number of tokens SGLang can handle, based on the current server configuration."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# get_max_total_num_tokens\n",
+    "\n",
+    "url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
+    "\n",
+    "response = requests.get(url)\n",
+    "print_highlight(response.text)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 8ba65c70c..d4730e3f7 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -167,9 +167,12 @@ class DataParallelController:
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
 
-        # Wait for model to finish loading
+        # Wait for the model to finish loading and collect max_total_num_tokens
+        scheduler_info = []
         for i in range(len(scheduler_pipe_readers)):
-            scheduler_pipe_readers[i].recv()
+            scheduler_info.append(scheduler_pipe_readers[i].recv())
+        # Assume all schedulers report the same max_total_num_tokens
+        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
 
         return send_to
 
@@ -191,7 +194,10 @@
         send_to = get_zmq_socket(
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
-        reader.recv()
+
+        scheduler_info = reader.recv()
+        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
+
         return send_to
 
     def round_robin_scheduler(self, req):
@@ -233,7 +239,9 @@ def run_data_parallel_controller_process(
 
     try:
         controller = DataParallelController(server_args, port_args)
-        pipe_writer.send("ready")
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+        )
         controller.event_loop()
     except Exception:
         msg = get_exception_traceback()
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 792692ed6..de3c753ef 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1400,7 +1400,9 @@ def run_scheduler_process(
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
-        pipe_writer.send("ready")
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+        )
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 55fa3a6ea..c2aa73e36 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -102,6 +102,7 @@ app.add_middleware(
 )
 
 tokenizer_manager: TokenizerManager = None
+_max_total_num_tokens = None
 
 
 ##### Native API endpoints #####
@@ -184,6 +185,17 @@
     )
 
 
+@app.get("/get_max_total_num_tokens")
+async def get_max_total_num_tokens():
+    try:
+        return {"max_total_num_tokens": _get_max_total_num_tokens()}
+
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
 async def get_memory_pool_size():
     """Get the memory pool size in number of tokens"""
@@ -390,6 +402,7 @@ def launch_engine(
     """
 
    global tokenizer_manager
+    global _max_total_num_tokens
 
     # Configure global environment
     configure_logger(server_args)
@@ -455,9 +468,19 @@
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
 
-    # Wait for model to finish loading
+    # Wait for the model to finish loading and collect max_total_num_tokens
+    scheduler_info = []
     for i in range(len(scheduler_pipe_readers)):
-        scheduler_pipe_readers[i].recv()
+        data = scheduler_pipe_readers[i].recv()
+
+        if data["status"] != "ready":
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
+        scheduler_info.append(data)
+
+    # Assume all schedulers have same max_total_num_tokens
+    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
 
 
 def launch_server(
@@ -518,6 +541,10 @@
         t.join()
 
 
+def _get_max_total_num_tokens():
+    return _max_total_num_tokens
+
+
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -759,6 +786,15 @@
         response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())
 
+    def get_max_total_num_tokens(self):
+        response = requests.get(f"{self.url}/get_max_total_num_tokens")
+        if response.status_code == 200:
+            return response.json()["max_total_num_tokens"]
+        else:
+            raise RuntimeError(
+                f"Failed to get max tokens. {response.json()['error']['message']}"
+            )
+
     def __del__(self):
         self.shutdown()
 
@@ -908,3 +944,6 @@
         # get the current event loop
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))
+
+    def get_max_total_num_tokens(self):
+        return _get_max_total_num_tokens()
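
For reviewers: a minimal end-to-end sketch of exercising the new accessor through the `Runtime` wrapper extended above. This is not part of the patch, and the model path is illustrative; any model the server can load works the same way.

    import sglang as sgl

    # Launch a local server process; this returns once the schedulers have
    # reported {"status": "ready", ...} through the handshake changed here.
    runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")

    # New accessor added in this patch: issues an HTTP GET against
    # /get_max_total_num_tokens and unwraps "max_total_num_tokens".
    print(runtime.get_max_total_num_tokens())

    runtime.shutdown()

The printed value is the KV-cache token capacity the scheduler computed for the current model and GPU configuration, so clients can size batches or prompts without probing the server empirically.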