diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index fa8c8d174..020e96e65 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -22,7 +22,7 @@ import signal import sys import time import uuid -from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union import fastapi import uvloop @@ -173,6 +173,15 @@ class TokenizerManager: # Others self.gracefully_exit = False + self.init_weights_update_group_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_weights_from_distributed_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.get_weights_by_name_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) # Metrics if self.enable_metrics: @@ -190,8 +199,7 @@ class TokenizerManager: ): created_time = time.time() - if self.to_create_loop: - self.create_handle_loop() + self.auto_create_handle_loop() if isinstance(obj, EmbeddingReqInput) and self.is_generation: raise ValueError( @@ -440,8 +448,7 @@ class TokenizerManager: obj: UpdateWeightFromDiskReqInput, request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: - if self.to_create_loop: - self.create_handle_loop() + self.auto_create_handle_loop() # default the load format to the server_args if obj.load_format is None: @@ -456,7 +463,7 @@ class TokenizerManager: async def _wait_for_model_update_from_disk( self, obj: UpdateWeightFromDiskReqInput - ) -> Tuple[bool, str, int]: + ) -> Tuple[bool, str]: self.send_to_scheduler.send_pyobj(obj) self.model_update_result = asyncio.Future() if self.server_args.dp_size == 1: @@ -485,15 +492,11 @@ class TokenizerManager: obj: InitWeightsUpdateGroupReqInput, request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: - if self.to_create_loop: - 
self.create_handle_loop() - self.send_to_scheduler.send_pyobj(obj) - - self.init_weights_update_group_result = asyncio.Future() + self.auto_create_handle_loop() assert ( self.server_args.dp_size == 1 ), "dp_size must be 1 for init parameter update group" - result = await self.init_weights_update_group_result + result = (await self.init_weights_update_group_communicator(obj))[0] return result.success, result.message async def update_weights_from_distributed( @@ -501,44 +504,32 @@ class TokenizerManager: obj: UpdateWeightsFromDistributedReqInput, request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: - if self.to_create_loop: - self.create_handle_loop() + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for update weights from distributed" # This means that weight sync # cannot run while requests are in progress. async with self.model_update_lock.writer_lock: - self.send_to_scheduler.send_pyobj(obj) - self.parameter_update_result: Awaitable[ UpdateWeightsFromDistributedReqOutput ] = asyncio.Future() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be for update weights from distributed" - result = await self.parameter_update_result + result = (await self.update_weights_from_distributed_communicator(obj))[0] return result.success, result.message async def get_weights_by_name( self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None ): - if self.to_create_loop: - self.create_handle_loop() - - self.send_to_scheduler.send_pyobj(obj) - self.get_weights_by_name_result = asyncio.Future() + self.auto_create_handle_loop() + results = await self.get_weights_by_name_communicator(obj) + all_parameters = [r.parameter for r in results] if self.server_args.dp_size == 1: - result = await self.get_weights_by_name_result - return result.parameter + return all_parameters[0] else: - self.get_weights_by_name_tmp = [] - result = await self.get_weights_by_name_result - all_parameters = 
[r.parameter for r in result] return all_parameters async def open_session( self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None ): - if self.to_create_loop: - self.create_handle_loop() + self.auto_create_handle_loop() session_id = uuid.uuid4().hex obj.session_id = session_id @@ -568,7 +559,7 @@ class TokenizerManager: background_tasks.add_task(abort_request) return background_tasks - def create_handle_loop(self): + def auto_create_handle_loop(self): if not self.to_create_loop: return @@ -711,21 +702,14 @@ class TokenizerManager: assert ( self.server_args.dp_size == 1 ), "dp_size must be 1 for init parameter update group" - self.init_weights_update_group_result.set_result(recv_obj) + self.init_weights_update_group_communicator.handle_recv(recv_obj) elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput): assert ( self.server_args.dp_size == 1 ), "dp_size must be 1 for update weights from distributed" - self.parameter_update_result.set_result(recv_obj) + self.update_weights_from_distributed_communicator.handle_recv(recv_obj) elif isinstance(recv_obj, GetWeightsByNameReqOutput): - if self.server_args.dp_size == 1: - self.get_weights_by_name_result.set_result(recv_obj) - else: - self.get_weights_by_name_tmp.append(recv_obj) - if len(self.get_weights_by_name_tmp) == self.server_args.dp_size: - self.get_weights_by_name_result.set_result( - self.get_weights_by_name_tmp - ) + self.get_weights_by_name_communicator.handle_recv(recv_obj) else: raise ValueError(f"Invalid object: {recv_obj=}") @@ -809,3 +793,28 @@ class SignalHandler: f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..." 
) self.tokenizer_manager.gracefully_exit = True + + +T = TypeVar("T") + + +class _Communicator(Generic[T]): + def __init__(self, sender, fan_out: int): + self._sender = sender + self._fan_out = fan_out + self._result_future: Optional[asyncio.Future] = None + self._result_values: Optional[List[T]] = None + + async def __call__(self, obj): + self._sender.send_pyobj(obj) + self._result_future = asyncio.Future() + self._result_values = [] + await self._result_future + result_values = self._result_values + self._result_future = self._result_values = None + return result_values + + def handle_recv(self, recv_obj: T): + self._result_values.append(recv_obj) + if len(self._result_values) == self._fan_out: + self._result_future.set_result(None) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index b007d0d98..4814c8c6f 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): try: ret = await tokenizer_manager.get_weights_by_name(obj, request) if ret is None: - return ORJSONResponse( - {"error": {"message": "Get parameter by name failed"}}, - status_code=HTTPStatus.BAD_REQUEST, - ) + return _create_error_response("Get parameter by name failed") else: return ORJSONResponse(ret, status_code=200) except Exception as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) + return _create_error_response(e) @app.api_route("/open_session", methods=["GET", "POST"]) @@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request): session_id = await tokenizer_manager.open_session(obj, request) return session_id except Exception as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) + return _create_error_response(e) @app.api_route("/close_session", methods=["GET", "POST"]) @@ -276,9 +269,7 @@ async def close_session(obj: 
CloseSessionReqInput, request: Request): await tokenizer_manager.close_session(obj, request) return Response(status_code=200) except Exception as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) + return _create_error_response(e) # fastapi implicitly converts json in the request to obj (dataclass) @@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): return ret except ValueError as e: logger.error(f"Error: {e}") - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) + return _create_error_response(e) @app.api_route("/encode", methods=["POST", "PUT"]) @@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): ret = await tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) + return _create_error_response(e) @app.api_route("/classify", methods=["POST", "PUT"]) @@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): ret = await tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) + return _create_error_response(e) ##### OpenAI-compatible API endpoints ##### @@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str): return await v1_retrieve_file_content(file_id) +def _create_error_response(e): + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + def launch_engine( server_args: ServerArgs, ): @@ -849,12 +840,10 @@ class Engine: group_name=group_name, backend=backend, ) - - async def _init_group(): - return await tokenizer_manager.init_weights_update_group(obj, None) - loop = asyncio.get_event_loop() - return loop.run_until_complete(_init_group()) 
+ return loop.run_until_complete( + tokenizer_manager.init_weights_update_group(obj, None) + ) def update_weights_from_distributed(self, name, dtype, shape): """Update weights from distributed source.""" @@ -863,22 +852,16 @@ class Engine: dtype=dtype, shape=shape, ) - - async def _update_weights(): - return await tokenizer_manager.update_weights_from_distributed(obj, None) - loop = asyncio.get_event_loop() - return loop.run_until_complete(_update_weights()) + return loop.run_until_complete( + tokenizer_manager.update_weights_from_distributed(obj, None) + ) def get_weights_by_name(self, name, truncate_size=100): """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) - - async def _get_weights(): - return await tokenizer_manager.get_weights_by_name(obj, None) - loop = asyncio.get_event_loop() - return loop.run_until_complete(_get_weights()) + return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None)) class Runtime: