diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 2b576b409..29df74b18 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -814,18 +815,24 @@ def _launch_subprocesses(
         ),
     )
     detoken_proc.start()

+    if server_args.tokenizer_worker_num > 1:
+        # Launch the multi-tokenizer router
+        tokenizer_manager = MultiTokenizerRouter(server_args, port_args)

-    # Launch tokenizer process
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+        # Templates are initialized later, inside each tokenizer worker
+        template_manager = None
+    else:
+        # Launch tokenizer process
+        tokenizer_manager = TokenizerManager(server_args, port_args)

-    # Initialize templates
-    template_manager = TemplateManager()
-    template_manager.initialize_templates(
-        tokenizer_manager=tokenizer_manager,
-        model_path=server_args.model_path,
-        chat_template=server_args.chat_template,
-        completion_template=server_args.completion_template,
-    )
+        # Initialize templates
+        template_manager = TemplateManager()
+        template_manager.initialize_templates(
+            tokenizer_manager=tokenizer_manager,
+            model_path=server_args.model_path,
+            chat_template=server_args.chat_template,
+            completion_template=server_args.completion_template,
+        )

     # Wait for the model to finish loading
     scheduler_infos = []
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 5d6e03ac3..70d7deb1e 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -23,6 +23,7 @@ import json
 import logging
 import multiprocessing as multiprocessing
 import os
+import tempfile
 import threading
 import time
 from http import HTTPStatus
@@ -91,11 +92,18 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightVersionReqInput,
     VertexGenerateReqInput,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import (
+    MultiTokenizerManager,
+    deserialize_data,
+    get_main_process_id,
+    read_from_shared_memory,
+    write_data_for_multi_tokenizer,
+)
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.reasoning_parser import ReasoningParser
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
     add_prometheus_middleware,
@@ -130,8 +138,79 @@ def set_global_state(global_state: _GlobalState):
     _global_state = global_state


+# Set up all middlewares for multi-tokenizer compatibility
+def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
+    """Set up all middlewares for both single- and multi-process modes."""
+    worker_pid = os.getpid()
+
+    if api_key:
+        add_api_key_middleware(app, api_key)
+        logger.info(f"Worker {worker_pid} added API key middleware")
+
+    if enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+        logger.info(f"Worker {worker_pid} added prometheus middleware")
+
+
+async def init_multi_tokenizer() -> ServerArgs:
+    """Read args from shared memory and initialize a tokenizer manager for the current worker process."""
+    pid = os.getpid()
+    main_pid = get_main_process_id()
+    logger.info(f"Current worker_id: {pid}, main process ID: {main_pid}")
+
+    # Read configuration from shared memory
+    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
+    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
+    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
+    port_args, server_args = deserialize_data(port_args_data, server_args_data)
+    scheduler_info = scheduler_info_data
+
+    port_args.tokenizer_ipc_name = (
+        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
+    )
+
+    # Create the tokenizer manager for this worker process
+    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
+    # Register this tokenizer with the main tokenizer manager
+    await tokenizer_manager.register_to_main_tokenizer_manager()
+
+    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
+    set_global_state(
+        _GlobalState(
+            tokenizer_manager=tokenizer_manager,
+            template_manager=template_manager,
+            scheduler_info=scheduler_info,
+        )
+    )
+    return server_args
+
+
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
+    server_args = getattr(fast_api_app, "server_args", None)
+    if server_args is None:
+        # Initialize multi-tokenizer support for worker processes
+        fast_api_app.server_args = await init_multi_tokenizer()
+        server_args = fast_api_app.server_args  # keep the local reference in sync for shutdown
+        setup_middlewares(
+            fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
+        )
+        fast_api_app.warmup_thread = threading.Thread(
+            target=_wait_and_warmup,
+            args=(
+                fast_api_app.server_args,
+                None,  # pipe_finish_writer not needed in worker
+                None,  # launch_callback not needed in worker
+            ),
+        )
+
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
@@ -191,7 +270,15 @@ async def lifespan(fast_api_app: FastAPI):
     warmup_thread = getattr(fast_api_app, "warmup_thread", None)
     if warmup_thread is not None:
         warmup_thread.start()
-    yield
+
+    try:
+        yield
+    finally:
+        if server_args.tokenizer_worker_num > 1:
+            pid = os.getpid()
+            logger.info(f"uvicorn worker {pid} ending...")
+            warmup_thread.join()
+            logger.info(f"uvicorn worker {pid} ended.")


 # Fast API
@@ -1078,9 +1165,19 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-        server_args=server_args
-    )
+    if server_args.tokenizer_worker_num > 1:
+        port_args = PortArgs.init_new(server_args)
+        port_args.tokenizer_worker_ipc_name = (
+            f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
+        )
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+            server_args=server_args, port_args=port_args
+        )
+    else:
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+            server_args=server_args,
+        )
+
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
@@ -1089,42 +1186,75 @@ def launch_server(
         )
     )

-    # Add api key authorization
-    if server_args.api_key:
-        add_api_key_middleware(app, server_args.api_key)
+    if server_args.tokenizer_worker_num > 1:
+        port_args_shm, server_args_shm, scheduler_info_shm = (
+            write_data_for_multi_tokenizer(
+                port_args,
+                server_args,
+                scheduler_info,
+            )
+        )
+    else:
+        # Add api key authorization
+        if server_args.api_key:
+            add_api_key_middleware(app, server_args.api_key)

-    # Add prometheus middleware
-    if server_args.enable_metrics:
-        add_prometheus_middleware(app)
-        enable_func_timer()
+        # Add prometheus middleware
+        if server_args.enable_metrics:
+            add_prometheus_middleware(app)
+            enable_func_timer()

-    # Send a warmup request - we will create the thread launch it
-    # in the lifespan after all other warmups have fired.
-    warmup_thread = threading.Thread(
-        target=_wait_and_warmup,
-        args=(
-            server_args,
-            pipe_finish_writer,
-            launch_callback,
-        ),
-    )
-    app.warmup_thread = warmup_thread
+        # Send a warmup request - we will create the thread and launch it
+        # in the lifespan after all other warmups have fired.
+        warmup_thread = threading.Thread(
+            target=_wait_and_warmup,
+            args=(
+                server_args,
+                pipe_finish_writer,
+                launch_callback,
+            ),
+        )
+        app.warmup_thread = warmup_thread

     try:
         # Update logging configs
         set_uvicorn_logging_configs()
         app.server_args = server_args
         # Listen for HTTP requests
-        uvicorn.run(
-            app,
-            host=server_args.host,
-            port=server_args.port,
-            log_level=server_args.log_level_http or server_args.log_level,
-            timeout_keep_alive=5,
-            loop="uvloop",
-        )
+        if server_args.tokenizer_worker_num > 1:
+            from uvicorn.config import LOGGING_CONFIG
+
+            LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
+                "handlers": ["default"],
+                "level": "INFO",
+                "propagate": False,
+            }
+            uvicorn.run(
+                "sglang.srt.entrypoints.http_server:app",
+                host=server_args.host,
+                port=server_args.port,
+                log_level=server_args.log_level_http or server_args.log_level,
+                timeout_keep_alive=5,
+                loop="uvloop",
+                workers=server_args.tokenizer_worker_num,
+            )
+        else:
+            uvicorn.run(
+                app,
+                host=server_args.host,
+                port=server_args.port,
+                log_level=server_args.log_level_http or server_args.log_level,
+                timeout_keep_alive=5,
+                loop="uvloop",
+            )
     finally:
-        warmup_thread.join()
+        if server_args.tokenizer_worker_num > 1:
+            port_args_shm.unlink()
+            server_args_shm.unlink()
+            scheduler_info_shm.unlink()
+            _global_state.tokenizer_manager.clear_tokenizer_mapping()
+        else:
+            warmup_thread.join()


 def _execute_server_warmup(
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index c86149907..83abd2331 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -32,11 +32,14 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     FreezeGCReq,
+    MultiTokenizerRegisterReq,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
     freeze_gc,
+    get_worker_ids_from_req_rids,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -67,7 +70,7 @@ class DecodeStatus:
     sent_offset: int = 0


-class DetokenizerManager:
+class DetokenizerManager(MultiTokenizerMixin):
     """DetokenizerManager is a process that detokenizes the token ids."""

     def __init__(
@@ -102,6 +105,7 @@ class DetokenizerManager:
                 (BatchEmbeddingOut, self.handle_batch_embedding_out),
                 (BatchTokenIDOut, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
+                (MultiTokenizerRegisterReq, lambda x: x),
                 (FreezeGCReq, self.handle_freeze_gc_req),
             ]
         )
@@ -116,6 +120,39 @@ class DetokenizerManager:
             if output is not None:
                 self.send_to_tokenizer.send_pyobj(output)

+    def multi_tokenizer_manager_event_loop(self):
+        """The event loop that handles requests; used only in multi-tokenizer mode."""
+        self.create_sockets_mapping()
+        while True:
+            recv_obj = self.recv_from_scheduler.recv_pyobj()
+            output = self._request_dispatcher(recv_obj)
+            if output is None:
+                continue
+            # Extract worker_id from rid
+            if isinstance(recv_obj.rids, list):
+                worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
+            else:
+                raise RuntimeError(
+                    "recv_obj.rids must be a list when tokenizer_worker_num > 1"
+                )
+
+            # Send data using the corresponding socket
+            for i, worker_id in enumerate(worker_ids):
+                if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                    if self.register_tokenizer_ipc(recv_obj, worker_id):
+                        logger.info(
+                            f"DetokenizerManager created ZMQ socket for worker {worker_id}"
+                        )
+                    continue
+                else:
+                    if worker_id not in self.tokenizer_mapping:
+                        logger.error(
+                            f"Tokenizer worker ID {worker_id} not registered. "
+                            f"Check if server process {worker_id} is alive."
+                        )
+                        continue
+                    new_output = self._handle_output_by_index(output, i)
+                    self.tokenizer_mapping[worker_id].send_pyobj(new_output)
+
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
     ):
@@ -285,8 +322,12 @@ def run_detokenizer_process(

     try:
         manager = DetokenizerManager(server_args, port_args)
-        manager.event_loop()
+        if server_args.tokenizer_worker_num > 1:
+            manager.multi_tokenizer_manager_event_loop()
+        else:
+            manager.event_loop()
     except Exception:
+        manager.clear_tokenizer_mapping()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 917d387fe..1a99e0b5a 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -983,6 +983,11 @@ class AbortReq:
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
+    # Used in MultiTokenizerManager mode: mirrors rid so aborts can be routed
+    rids: Optional[Union[List[str], str]] = None
+
+    def __post_init__(self):
+        self.rids = self.rid


 @dataclass
@@ -1183,6 +1188,18 @@ class LoRAUpdateResult:
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


+@dataclass
+class MultiTokenizerRegisterReq:
+    rids: Optional[Union[List[str], str]] = None
+    ipc_name: Optional[str] = None
+
+
+@dataclass
+class MultiTokenizerWarpper:
+    worker_id: int
+    obj: Optional[Any] = None
+
+
 class BlockReqType(Enum):
     BLOCK = 1
     UNBLOCK = 2
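The routing in this change hinges on a single rid convention: each tokenizer worker prefixes request ids with its own pid, and downstream processes parse that prefix to find the owning worker. A small self-contained sketch of the scheme, mirroring `get_worker_ids_from_req_rids` and `get_origin_rid` from the utils.py hunk later in this patch:

def tag_rid(worker_id: int, rid: str) -> str:
    # Same shape as the prefixing done in TokenizerManager.generate_request
    return f"{worker_id}_{rid}"


def worker_id_of(rid: str) -> int:
    return int(rid.split("_")[0])


def origin_rid_of(rid: str) -> str:
    return rid.split("_", 1)[1] if "_" in rid else rid


tagged = tag_rid(4242, "req-7")
assert worker_id_of(tagged) == 4242
assert origin_rid_of(tagged) == "req-7"
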
diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py
new file mode 100644
index 000000000..86d057457
--- /dev/null
+++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py
@@ -0,0 +1,591 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MultiTokenizerMixin provides the methods shared by MultiTokenizerManager and DetokenizerManager."""
+import asyncio
+import dataclasses
+import json
+import logging
+import multiprocessing
+import os
+import sys
+import threading
+from multiprocessing import shared_memory
+from typing import Dict
+
+import zmq
+import zmq.asyncio
+
+from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
+from sglang.srt.managers.io_struct import (
+    BatchEmbeddingOut,
+    BatchMultimodalOut,
+    BatchStrOut,
+    BatchTokenIDOut,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
+)
+from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+    get_worker_ids_from_req_rids,
+    get_zmq_socket,
+    kill_process_tree,
+)
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+
+class MultiTokenizerMixin:
+    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+
+    def create_sockets_mapping(self):
+        if not hasattr(self, "tokenizer_mapping"):
+            self.tokenizer_mapping = {}
+        # Create ZMQ context if needed
+        if not hasattr(self, "_zmq_context"):
+            self._zmq_context = zmq.Context()
+
+    def init_tokenizer_mapping(
+        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
+    ):
+        """Init the tokenizer mapping from a register request."""
+        ipc_name = recv_obj.ipc_name
+        worker_id_int = int(worker_id)
+
+        if worker_id_int not in self.tokenizer_mapping:
+            socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
+            self.tokenizer_mapping[worker_id_int] = socket
+            self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
+            return True
+        else:
+            return False
+
+    def register_tokenizer_ipc(self, recv_obj, worker_id):
+        if worker_id not in self.tokenizer_mapping:
+            # Register the worker if not already done
+            if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                return self.init_tokenizer_mapping(recv_obj, worker_id)
+            else:
+                logger.error(
+                    f"Worker {worker_id} is not registered and was not found in the tokenizer mapping. "
+                    "Please ensure the worker is registered correctly."
+                )
+        return False
+
+    def _handle_output_by_index(self, output, i):
+        """Slice the i-th request out of a batch output.
+
+        NOTE: a more maintainable implementation would be preferable here.
+        """
+        if isinstance(output, BatchTokenIDOut):
+            new_output = BatchTokenIDOut(
+                rids=[output.rids[i]],
+                finished_reasons=(
+                    [output.finished_reasons[i]]
+                    if len(output.finished_reasons) > i
+                    else None
+                ),
+                decoded_texts=(
+                    [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
+                ),
+                decode_ids=(
+                    [output.decode_ids[i]] if len(output.decode_ids) > i else None
+                ),
+                read_offsets=(
+                    [output.read_offsets[i]] if len(output.read_offsets) > i else None
+                ),
+                output_ids=(
+                    [output.output_ids[i]]
+                    if output.output_ids and len(output.output_ids) > i
+                    else None
+                ),
+                skip_special_tokens=(
+                    [output.skip_special_tokens[i]]
+                    if len(output.skip_special_tokens) > i
+                    else None
+                ),
+                spaces_between_special_tokens=(
+                    [output.spaces_between_special_tokens[i]]
+                    if len(output.spaces_between_special_tokens) > i
+                    else None
+                ),
+                no_stop_trim=(
+                    [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
+                ),
+                prompt_tokens=(
+                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                ),
+                completion_tokens=(
+                    [output.completion_tokens[i]]
+                    if len(output.completion_tokens) > i
+                    else None
+                ),
+                cached_tokens=(
+                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                ),
+                spec_verify_ct=(
+                    [output.spec_verify_ct[i]]
+                    if len(output.spec_verify_ct) > i
+                    else None
+                ),
+                input_token_logprobs_val=(
+                    [output.input_token_logprobs_val[i]]
+                    if output.input_token_logprobs_val
+                    else None
+                ),
+                input_token_logprobs_idx=(
+                    [output.input_token_logprobs_idx[i]]
+                    if output.input_token_logprobs_idx
+                    else None
+                ),
+                output_token_logprobs_val=(
+                    [output.output_token_logprobs_val[i]]
+                    if output.output_token_logprobs_val
+                    else None
+                ),
+                output_token_logprobs_idx=(
+                    [output.output_token_logprobs_idx[i]]
+                    if output.output_token_logprobs_idx
+                    else None
+                ),
+                input_top_logprobs_val=(
+                    [output.input_top_logprobs_val[i]]
+                    if output.input_top_logprobs_val
+                    else None
+                ),
+                input_top_logprobs_idx=(
+                    [output.input_top_logprobs_idx[i]]
+                    if output.input_top_logprobs_idx
+                    else None
+                ),
+                output_top_logprobs_val=(
+                    [output.output_top_logprobs_val[i]]
+                    if output.output_top_logprobs_val
+                    else None
+                ),
+                output_top_logprobs_idx=(
+                    [output.output_top_logprobs_idx[i]]
+                    if output.output_top_logprobs_idx
+                    else None
+                ),
+                input_token_ids_logprobs_val=(
+                    [output.input_token_ids_logprobs_val[i]]
+                    if output.input_token_ids_logprobs_val
+                    else None
+                ),
+                input_token_ids_logprobs_idx=(
+                    [output.input_token_ids_logprobs_idx[i]]
+                    if output.input_token_ids_logprobs_idx
+                    else None
+                ),
+                output_token_ids_logprobs_val=(
+                    [output.output_token_ids_logprobs_val[i]]
+                    if output.output_token_ids_logprobs_val
+                    else None
+                ),
+                output_token_ids_logprobs_idx=(
+                    [output.output_token_ids_logprobs_idx[i]]
+                    if output.output_token_ids_logprobs_idx
+                    else None
+                ),
+                output_hidden_states=(
+                    [output.output_hidden_states[i]]
+                    if output.output_hidden_states
+                    else None
+                ),
+            )
+        elif isinstance(output, BatchEmbeddingOut):
+            new_output = BatchEmbeddingOut(
+                rids=[output.rids[i]],
+                finished_reasons=(
+                    [output.finished_reasons[i]]
+                    if len(output.finished_reasons) > i
+                    else None
+                ),
+                embeddings=(
+                    [output.embeddings[i]] if len(output.embeddings) > i else None
+                ),
+                prompt_tokens=(
+                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                ),
+                cached_tokens=(
+                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                ),
+            )
+        elif isinstance(output, BatchStrOut):
+            new_output = BatchStrOut(
+                rids=[output.rids[i]],
+                finished_reasons=(
+                    [output.finished_reasons[i]]
+                    if len(output.finished_reasons) > i
+                    else None
+                ),
+                output_strs=(
+                    [output.output_strs[i]] if len(output.output_strs) > i else None
+                ),
+                output_ids=(
+                    [output.output_ids[i]]
+                    if output.output_ids and len(output.output_ids) > i
+                    else None
+                ),
+                prompt_tokens=(
+                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                ),
+                completion_tokens=(
+                    [output.completion_tokens[i]]
+                    if len(output.completion_tokens) > i
+                    else None
+                ),
+                cached_tokens=(
+                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                ),
+                spec_verify_ct=(
+                    [output.spec_verify_ct[i]]
+                    if len(output.spec_verify_ct) > i
+                    else None
+                ),
+                input_token_logprobs_val=(
+                    [output.input_token_logprobs_val[i]]
+                    if output.input_token_logprobs_val
+                    else None
+                ),
+                input_token_logprobs_idx=(
+                    [output.input_token_logprobs_idx[i]]
+                    if output.input_token_logprobs_idx
+                    else None
+                ),
+                output_token_logprobs_val=(
+                    [output.output_token_logprobs_val[i]]
+                    if output.output_token_logprobs_val
+                    else None
+                ),
+                output_token_logprobs_idx=(
+                    [output.output_token_logprobs_idx[i]]
+                    if output.output_token_logprobs_idx
+                    else None
+                ),
+                input_top_logprobs_val=(
+                    [output.input_top_logprobs_val[i]]
+                    if output.input_top_logprobs_val
+                    else None
+                ),
+                input_top_logprobs_idx=(
+                    [output.input_top_logprobs_idx[i]]
+                    if output.input_top_logprobs_idx
+                    else None
+                ),
+                output_top_logprobs_val=(
+                    [output.output_top_logprobs_val[i]]
+                    if output.output_top_logprobs_val
+                    else None
+                ),
+                output_top_logprobs_idx=(
+                    [output.output_top_logprobs_idx[i]]
+                    if output.output_top_logprobs_idx
+                    else None
+                ),
+                input_token_ids_logprobs_val=(
+                    [output.input_token_ids_logprobs_val[i]]
+                    if output.input_token_ids_logprobs_val
+                    else None
+                ),
+                input_token_ids_logprobs_idx=(
+                    [output.input_token_ids_logprobs_idx[i]]
+                    if output.input_token_ids_logprobs_idx
+                    else None
+                ),
+                output_token_ids_logprobs_val=(
+                    [output.output_token_ids_logprobs_val[i]]
+                    if output.output_token_ids_logprobs_val
+                    else None
+                ),
+                output_token_ids_logprobs_idx=(
+                    [output.output_token_ids_logprobs_idx[i]]
+                    if output.output_token_ids_logprobs_idx
+                    else None
+                ),
+                output_hidden_states=(
+                    [output.output_hidden_states[i]]
+                    if output.output_hidden_states
+                    else None
+                ),
+            )
+        elif isinstance(output, BatchMultimodalOut):
+            new_output = BatchMultimodalOut(
+                rids=[output.rids[i]],
+                finished_reasons=(
+                    [output.finished_reasons[i]]
+                    if len(output.finished_reasons) > i
+                    else None
+                ),
+                outputs=([output.outputs[i]] if len(output.outputs) > i else None),
+                prompt_tokens=(
+                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                ),
+                completion_tokens=(
+                    [output.completion_tokens[i]]
+                    if len(output.completion_tokens) > i
+                    else None
+                ),
+                cached_tokens=(
+                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                ),
+            )
+        else:
+            new_output = output
+        return new_output
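+    # A hedged sketch of the "more maintainable" alternative the NOTE above
+    # asks for: slice any batch dataclass generically. It assumes every
+    # populated field is either None or a per-request list aligned with
+    # `rids`, which is close to (but not exactly) how the fields above are
+    # index-checked today, so it is illustrative rather than a drop-in:
+    #
+    #     def _slice_batch_out_generic(self, output, i):
+    #         kwargs = {}
+    #         for f in dataclasses.fields(output):
+    #             value = getattr(output, f.name)
+    #             if isinstance(value, list) and len(value) > i:
+    #                 kwargs[f.name] = [value[i]]
+    #             else:
+    #                 kwargs[f.name] = None
+    #         return type(output)(**kwargs)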
+    def clear_tokenizer_mapping(self):
+        if hasattr(self, "tokenizer_mapping"):
+            for socket in self.tokenizer_mapping.values():
+                try:
+                    socket.close()
+                except Exception as e:
+                    logger.warning(f"Failed to close socket: {e}")
+            self.tokenizer_mapping.clear()
+
+
+class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
+    """A router that receives requests from MultiTokenizerManager workers"""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+    ):
+        self.server_args = server_args
+        context = zmq.asyncio.Context(3)
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+        )
+        self.receive_from_worker = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
+        )
+        self._loop = asyncio.new_event_loop()
+        self._thread = threading.Thread(target=self._run_loop, daemon=True)
+        self._thread.start()
+        self._task = asyncio.run_coroutine_threadsafe(
+            self.router_worker_obj(), self._loop
+        )
+        # Start handle_loop on the same background event loop
+        self._handle_task = asyncio.run_coroutine_threadsafe(
+            print_exception_wrapper(self.handle_loop), self._loop
+        )
+        self.init_disaggregation()
+
+    def _run_loop(self):
+        self._loop.run_forever()
+
+    async def router_worker_obj(self):
+        while True:
+            recv_obj = await self.receive_from_worker.recv_pyobj()
+            await self.send_to_scheduler.send_pyobj(recv_obj)
+
+    async def handle_loop(self):
+        # Special requests arrive from the scheduler and must be routed to the right worker
+        self.create_sockets_mapping()
+        while True:
+            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+            await self._distribute_result_to_workers(recv_obj)
+
+    async def _distribute_result_to_workers(self, recv_obj):
+        """Distribute a result to the corresponding workers based on rid"""
+        if isinstance(recv_obj, MultiTokenizerWarpper):
+            worker_ids = [recv_obj.worker_id]
+            recv_obj = recv_obj.obj
+        else:
+            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
+
+        if len(worker_ids) == 0:
+            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
+            return
+
+        # Distribute the result to each worker
+        for i, worker_id in enumerate(worker_ids):
+            if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                if self.register_tokenizer_ipc(recv_obj, worker_id):
+                    logger.info(
+                        f"MultiTokenizerRouter created ZMQ socket for worker {worker_id}"
+                    )
+                continue
+            else:
+                if worker_id not in self.tokenizer_mapping:
+                    logger.error(
+                        f"Tokenizer worker ID {worker_id} not registered. "
+                        f"Check if server process {worker_id} is alive."
+                    )
+                    continue
+                new_recv_obj = self._handle_output_by_index(recv_obj, i)
+                self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
+
+
+class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
+    """Multi-process tokenizer manager that tokenizes the text."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+    ):
+        # Prevent the prefill bootstrap server from being initialized twice
+        disaggregation_mode = server_args.disaggregation_mode
+        server_args.disaggregation_mode = "null"
+        super().__init__(server_args, port_args)
+
+        self.worker_id = os.getpid()
+        self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
+
+        # For PD disaggregation
+        self.server_args.disaggregation_mode = disaggregation_mode
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Communicator
+        self.register_multi_tokenizer_communicator = _Communicator(
+            self.send_to_scheduler, 2
+        )
+        self._result_dispatcher._mapping.append(
+            (
+                MultiTokenizerRegisterReq,
+                self.register_multi_tokenizer_communicator.handle_recv,
+            )
+        )
+
+    async def register_to_main_tokenizer_manager(self):
+        """Register this worker with the main tokenizer manager"""
+        # Create a handle loop to receive messages from the main tokenizer manager
+        self.auto_create_handle_loop()
+        req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
+        req.ipc_name = self.tokenizer_ipc_name
+        _Communicator.enable_multi_tokenizer = True
+        await self.register_multi_tokenizer_communicator(req)
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print its exception.
+    Wrap it in another layer so the exception is logged before exiting.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"MultiTokenizerRouter hit an exception: {traceback}")
+        if hasattr(func, "__self__") and isinstance(
+            func.__self__, MultiTokenizerRouter
+        ):
+            func.__self__.dump_requests_before_crash()
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
+
+def serialize_port_args(port_args: PortArgs) -> dict:
+    """Serialize PortArgs into a shareable dictionary"""
+    return {
+        "tokenizer_ipc_name": port_args.tokenizer_ipc_name,
+        "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
+        "detokenizer_ipc_name": port_args.detokenizer_ipc_name,
+        "nccl_port": port_args.nccl_port,
+        "rpc_ipc_name": port_args.rpc_ipc_name,
+        "metrics_ipc_name": port_args.metrics_ipc_name,
+        "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
+    }
+
+
+def deserialize_data(port_args: dict, server_args: dict):
+    """Deserialize PortArgs and ServerArgs from shared dictionaries"""
+    return PortArgs(**port_args), ServerArgs(**server_args)
+
+
+def serialize_server_args(server_args: ServerArgs) -> dict:
+    """Serialize ServerArgs into a shareable dictionary"""
+    return dataclasses.asdict(server_args)
+
+
+def serialize_scheduler_info(scheduler_info: Dict) -> dict:
+    """Serialize scheduler_info into a shareable dictionary"""
+    return scheduler_info
+
+
+def deserialize_scheduler_info(data: dict) -> Dict:
+    """Deserialize scheduler_info from a shared dictionary"""
+    return data
+
+
+def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
+    """Write data to shared memory"""
+    serialized = json.dumps(data).encode("utf-8")
+    size = len(serialized)
+    try:
+        # Try to open existing shared memory
+        shm = shared_memory.SharedMemory(name=name)
+        # If the size is insufficient, close and recreate
+        if shm.size < size:
+            shm.close()
+            shm.unlink()
+            shm = shared_memory.SharedMemory(create=True, size=size, name=name)
+    except FileNotFoundError:
+        # If not present, create new shared memory
+        shm = shared_memory.SharedMemory(create=True, size=size, name=name)
+
+    shm.buf[:size] = serialized
+    return shm
+
+
+def read_from_shared_memory(name: str) -> dict:
+    """Read data from shared memory"""
+    try:
+        shm = shared_memory.SharedMemory(name=name)
+        data = json.loads(bytes(shm.buf).decode("utf-8"))
+        shm.close()
+        return data
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Shared memory {name} not found")
+
+
+def get_main_process_id() -> int:
+    """Get the main process ID"""
+    return multiprocessing.current_process()._parent_pid
+
+
+def write_data_for_multi_tokenizer(
+    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
+):
+    """Write the args to shared memory for the multi-tokenizer workers"""
+    # Get the main process ID
+    main_pid = get_main_process_id()
+    current_pid = os.getpid()
+    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
+
+    # Write port_args to shared memory
+    port_args_shm = write_to_shared_memory(
+        serialize_port_args(port_args), f"port_args_{current_pid}"
+    )
+    # Write server_args to shared memory
+    server_args_shm = write_to_shared_memory(
+        serialize_server_args(server_args), f"server_args_{current_pid}"
+    )
+    # Write scheduler_info to shared memory
+    scheduler_info_shm = write_to_shared_memory(
+        serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
+    )
+
+    port_args_shm.close()
+    server_args_shm.close()
+    scheduler_info_shm.close()
+
+    return port_args_shm, server_args_shm, scheduler_info_shm
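The shared-memory helpers above are what let each forked uvicorn worker recover PortArgs/ServerArgs without inheriting them. A hedged, self-contained sketch of the same handshake; note that some platforms round the segment size up to a page, so this read side strips trailing NUL bytes, a guard the helpers above could also adopt:

import json
from multiprocessing import shared_memory


def write_shm(data: dict, name: str) -> shared_memory.SharedMemory:
    payload = json.dumps(data).encode("utf-8")
    shm = shared_memory.SharedMemory(create=True, size=len(payload), name=name)
    shm.buf[: len(payload)] = payload
    return shm


def read_shm(name: str) -> dict:
    shm = shared_memory.SharedMemory(name=name)
    raw = bytes(shm.buf).rstrip(b"\x00")  # tolerate page-size rounding
    shm.close()
    return json.loads(raw.decode("utf-8"))


shm = write_shm({"port": 30000}, "demo_args")
assert read_shm("demo_args") == {"port": 30000}
shm.unlink()  # the writer owns cleanup, as in launch_server's finally block
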
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 38ff0ef14..4bf76f78b 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -84,6 +84,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -257,7 +259,6 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
-
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -540,6 +541,7 @@ class Scheduler(
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
             ]
         )
@@ -1101,6 +1103,17 @@ class Scheduler(
                     )
                     self.send_to_tokenizer.send_pyobj(abort_req)
                     continue
+
+            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWarpper):
+                worker_id = recv_req.worker_id
+                recv_req = recv_req.obj
+                output = self._request_dispatcher(recv_req)
+                if output is not None:
+                    output = MultiTokenizerWarpper(worker_id, output)
+                    self.send_to_tokenizer.send_pyobj(output)
+                continue
+
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
@@ -2474,6 +2487,10 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

+    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return recv_req
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
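The scheduler hunk above adds a symmetric unwrap/re-wrap for control requests: a worker tags its request with its pid, and the scheduler tags the reply with the same id so the router can address the right worker. A toy reproduction of that round trip (class name kept as spelled in io_struct.py):

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class MultiTokenizerWarpper:
    worker_id: int
    obj: Optional[Any] = None


def scheduler_handle(recv_req: Any, dispatch: Callable[[Any], Any]) -> Any:
    # Mirrors the unwrap/re-wrap added to Scheduler's recv loop above
    if isinstance(recv_req, MultiTokenizerWarpper):
        output = dispatch(recv_req.obj)
        return MultiTokenizerWarpper(recv_req.worker_id, output) if output else None
    return dispatch(recv_req)


reply = scheduler_handle(
    MultiTokenizerWarpper(4242, "flush_cache"), dispatch=lambda o: f"ack:{o}"
)
assert (reply.worker_id, reply.obj) == (4242, "ack:flush_cache")
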
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 36fd4964b..53c6a8036 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -94,6 +94,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -131,6 +132,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -266,9 +268,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )

         # Request states
         self.no_create_loop = False
@@ -312,35 +320,7 @@ class TokenizerManager:
         self.lora_update_lock = asyncio.Lock()

         # For PD disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_disaggregation()

         # For load balancing
         self.current_load = 0
@@ -488,6 +468,37 @@ class TokenizerManager:
             ]
         )

+    def init_disaggregation(self):
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start the kv bootstrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # Only start the bootstrap server on the prefill tokenizer manager
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = "Failed to create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    # Raise a proper exception type; raising a plain string is a TypeError
+                    raise RuntimeError(error_message) from e

     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -497,6 +508,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

+        if self.server_args.tokenizer_worker_num > 1:
+            # Prefix the rid with this worker's id so responses can be routed back
+            if isinstance(obj.rid, list):
+                # If it is a list, prefix each element with worker_id
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it is a single value, prefix it with worker_id
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -1096,6 +1116,8 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1315,6 +1337,8 @@ class TokenizerManager:
         elif obj.session_id in self.session_futures:
             return None

+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)

         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1590,7 +1614,6 @@ class TokenizerManager:

     async def handle_loop(self):
         """The event loop that handles requests"""
-
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
@@ -1610,9 +1633,12 @@ class TokenizerManager:
                 )
                 continue

+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1918,6 +1944,9 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1930,7 +1959,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2116,6 +2145,8 @@ T = TypeVar("T")
 class _Communicator(Generic[T]):
     """Note: The communicator now only run up to 1 in-flight request at any time."""

+    enable_multi_tokenizer = False
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
@@ -2132,6 +2163,8 @@ class _Communicator(Generic[T]):
         assert self._result_values is None

         if obj:
+            if _Communicator.enable_multi_tokenizer:
+                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)

         self._result_event = asyncio.Event()
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8114a81aa..eaf4a5869 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -128,6 +128,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
+    tokenizer_worker_num: int = 1
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     model_loader_extra_config: str = "{}"
@@ -827,6 +828,12 @@ class ServerArgs:
         default=ServerArgs.tokenizer_path,
         help="The path of the tokenizer.",
     )
+    parser.add_argument(
+        "--tokenizer-worker-num",
+        type=int,
+        default=ServerArgs.tokenizer_worker_num,
+        help="The number of tokenizer workers (HTTP server processes).",
+    )
     parser.add_argument(
         "--tokenizer-mode",
         type=str,
@@ -2176,6 +2183,9 @@ class ServerArgs:
             self.chunked_prefill_size % self.page_size == 0
         ), "chunked_prefill_size must be divisible by page_size"

+        # Check multi-tokenizer
+        assert self.tokenizer_worker_num > 0, "tokenizer_worker_num must be >= 1"
+
     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"

@@ -2419,6 +2429,9 @@ class PortArgs:
     # The ipc filename for Scheduler to send metrics
     metrics_ipc_name: str

+    # The ipc filename for communication between the multi-tokenizer router and tokenizer workers
+    tokenizer_worker_ipc_name: Optional[str]
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         if server_args.nccl_port is None:
@@ -2442,6 +2455,7 @@ class PortArgs:
                 nccl_port=nccl_port,
                 rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                tokenizer_worker_ipc_name=None,
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -2475,6 +2489,7 @@ class PortArgs:
                 nccl_port=nccl_port,
                 rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
                 metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
+                tokenizer_worker_ipc_name=None,
             )
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index b5f6626a2..ae175b8c7 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2787,6 +2787,20 @@ def lru_cache_frozenset(maxsize=128):
     return decorator


+def get_worker_ids_from_req_rids(rids):
+    """Extract the worker-id prefix from one rid or a list of rids."""
+    if isinstance(rids, list):
+        worker_ids = [int(rid.split("_")[0]) for rid in rids]
+    elif isinstance(rids, str):
+        worker_ids = [int(rids.split("_")[0])]
+    else:
+        worker_ids = []
+    return worker_ids
+
+
+def get_origin_rid(rid):
+    """Strip the worker-id prefix and return the original rid."""
+    return rid.split("_", 1)[1] if "_" in rid else rid
+
+
 def apply_module_patch(target_module, target_function, wrappers):
     original_module, original_function = parse_module_path(
         target_module, target_function, False
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index cd219f082..8b4310f43 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -85,6 +85,7 @@ suites = {
         TestFile("test_mla_int8_deepseek_v3.py", 429),
         TestFile("test_mla_flashinfer.py", 302),
         TestFile("test_mla_fp8.py", 93),
+        TestFile("test_multi_tokenizer.py", 230),
         TestFile("test_no_chunked_prefill.py", 108),
         TestFile("test_no_overlap_scheduler.py", 234),
         TestFile("test_original_logprobs.py", 200),
diff --git a/test/srt/test_multi_tokenizer.py b/test/srt/test_multi_tokenizer.py
new file mode 100644
index 000000000..182454e5e
--- /dev/null
+++ b/test/srt/test_multi_tokenizer.py
@@ -0,0 +1,84 @@
+import unittest
+from types import SimpleNamespace
+
+import sglang.srt.managers.io_struct as io_struct
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    auto_config_device,
+    get_benchmark_args,
+    is_in_ci,
+    popen_launch_server,
+    run_benchmark,
+    write_github_step_summary,
+)
+
+
+class TestMultiTokenizer(CustomTestCase):
+    # from test_hicache.py
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--tokenizer-worker-num",
+                8,
+                "--mem-fraction-static",
+                0.7,
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+    def test_multi_tokenizer_ttft(self):
+        # from test_bench_serving.py run_bench_serving
+        args = get_benchmark_args(
+            base_url=self.base_url,
+            dataset_name="random",
+            dataset_path="",
+            tokenizer=None,
+            num_prompts=100,
+            random_input_len=4096,
+            random_output_len=2048,
+            sharegpt_context_len=None,
+            request_rate=1,
+            disable_stream=False,
+            disable_ignore_eos=False,
+            seed=0,
+            device=auto_config_device(),
+            lora_name=None,
+        )
+        res = run_benchmark(args)
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_multi_tokenizer_ttft\n"
+                f"median_e2e_latency_ms: {res['median_e2e_latency_ms']:.2f} ms\n"
+            )
+        self.assertLess(res["median_e2e_latency_ms"], 11000)
+        self.assertLess(res["median_ttft_ms"], 86)
+        self.assertLess(res["median_itl_ms"], 10)
+
+
+if __name__ == "__main__":
+    unittest.main()
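For completeness, a hedged usage sketch: launch the server with the new flag and issue a request. The worker-id rid prefixing stays internal, so meta_info["id"] still returns the original request id (model path and port below are illustrative):

# Launch (shell):
#   python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct \
#       --tokenizer-worker-num 8 --port 30000
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16},
    },
)
print(resp.json()["meta_info"]["id"])  # original rid, worker prefix already stripped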