From e719bb0e84b7ce507323c523ec2c41386b43623e Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sun, 7 Sep 2025 19:13:34 +0800 Subject: [PATCH] [1/2] Refactor multi-tokenizer manager (#10074) --- python/sglang/srt/entrypoints/engine.py | 36 +- python/sglang/srt/entrypoints/http_server.py | 63 +- .../srt/managers/detokenizer_manager.py | 8 +- python/sglang/srt/managers/disagg_service.py | 46 ++ .../srt/managers/multi_tokenizer_mixin.py | 706 ++++++++---------- .../sglang/srt/managers/tokenizer_manager.py | 47 +- 6 files changed, 421 insertions(+), 485 deletions(-) create mode 100644 python/sglang/srt/managers/disagg_service.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 5e5801fff..f704018e6 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs): mp.set_start_method("spawn", force=True) +def _init_tokenizer_manager( + server_args: ServerArgs, port_args: PortArgs +) -> TokenizerManager: + # Launch tokenizer process + tokenizer_manager = TokenizerManager(server_args, port_args) + + # Initialize templates + template_manager = TemplateManager() + template_manager.initialize_templates( + tokenizer_manager=tokenizer_manager, + model_path=server_args.model_path, + chat_template=server_args.chat_template, + completion_template=server_args.completion_template, + ) + + return tokenizer_manager, template_manager + + def _launch_subprocesses( server_args: ServerArgs, port_args: Optional[PortArgs] = None ) -> Tuple[TokenizerManager, TemplateManager, Dict]: @@ -816,23 +834,15 @@ def _launch_subprocesses( ), ) detoken_proc.start() + + # Init tokenizer manager first, as the bootstrap server is initialized here if server_args.tokenizer_worker_num > 1: # Launch multi-tokenizer router tokenizer_manager = MultiTokenizerRouter(server_args, port_args) - - # Initialize templates template_manager = None else: - # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args) - - # Initialize templates - template_manager = TemplateManager() - template_manager.initialize_templates( - tokenizer_manager=tokenizer_manager, - model_path=server_args.model_path, - chat_template=server_args.chat_template, - completion_template=server_args.completion_template, + tokenizer_manager, template_manager = _init_tokenizer_manager( + server_args, port_args ) # Wait for the model to finish loading @@ -856,5 +866,7 @@ def _launch_subprocesses( # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + return tokenizer_manager, template_manager, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index dc91d7e84..110292114 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.managers.multi_tokenizer_mixin import ( MultiTokenizerManager, - deserialize_data, get_main_process_id, read_from_shared_memory, write_data_for_multi_tokenizer, @@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState): _global_state = global_state -# Function to set up all middlewares for multi-tokenizer compatibility -def setup_middlewares(api_key: Optional[str], enable_metrics: bool): - """Setup all middlewares for both single and multi-process modes""" - worker_pid = os.getpid() - - if api_key: - add_api_key_middleware(app, api_key) - logger.info(f"Worker {worker_pid} added API key middleware") - - if enable_metrics: - add_prometheus_middleware(app) - enable_func_timer() - logger.info(f"Worker {worker_pid} added prometheus middleware") - - async def init_multi_tokenizer() -> ServerArgs: """Read args information from shm and init tokenizer manager for current process""" pid = os.getpid() @@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs: logger.info(f"current worker_id: {pid}, main processID: {main_pid}") # Read configuration from shared memory - port_args_data = read_from_shared_memory(f"port_args_{main_pid}") - server_args_data = read_from_shared_memory(f"server_args_{main_pid}") - scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}") - port_args, server_args = deserialize_data(port_args_data, server_args_data) - scheduler_info = scheduler_info_data + port_args, server_args, scheduler_info = read_from_shared_memory( + f"multi_tokenizer_args_{main_pid}" + ) + server_args: ServerArgs + + # API key authentication is not supported in multi-tokenizer mode + assert ( + server_args.api_key is None + ), "API key is not supported in multi-tokenizer mode" port_args.tokenizer_ipc_name = ( f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" @@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs: @asynccontextmanager async def lifespan(fast_api_app: FastAPI): - server_args = getattr(fast_api_app, "server_args", None) - if server_args is None: + if not getattr(fast_api_app, "is_single_tokenizer_mode", False): # Initialize multi-tokenizer support for worker processes - fast_api_app.server_args = await init_multi_tokenizer() - setup_middlewares( - fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics - ) + fast_api_app.server_args: ServerArgs = await init_multi_tokenizer() + + # only metrics middleware is supported in multi-tokenizer mode + worker_pid = os.getpid() + if fast_api_app.server_args.enable_metrics: + add_prometheus_middleware(app) + enable_func_timer() + + logger.info(f"Worker {worker_pid} added prometheus middleware") fast_api_app.warmup_thread = threading.Thread( target=_wait_and_warmup, args=( @@ -1187,12 +1179,10 @@ def launch_server( ) if server_args.tokenizer_worker_num > 1: - port_args_shm, server_args_shm, scheduler_info_shm = ( - write_data_for_multi_tokenizer( - port_args, - server_args, - scheduler_info, - ) + multi_tokenizer_args_shm = write_data_for_multi_tokenizer( + port_args, + server_args, + scheduler_info, ) else: # Add api key authorization @@ -1239,6 +1229,7 @@ def launch_server( workers=server_args.tokenizer_worker_num, ) else: + app.is_single_tokenizer_mode = True uvicorn.run( app, host=server_args.host, @@ -1249,10 +1240,8 @@ def launch_server( ) finally: if server_args.tokenizer_worker_num > 1: - port_args_shm.unlink() - server_args_shm.unlink() - scheduler_info_shm.unlink() - _global_state.tokenizer_manager.clear_tokenizer_mapping() + multi_tokenizer_args_shm.unlink() + _global_state.tokenizer_manager.socket_mapping.clear_all_sockets() else: warmup_thread.join() diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 624d90e97..5c75d888b 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -34,7 +34,7 @@ from sglang.srt.managers.io_struct import ( FreezeGCReq, MultiTokenizerRegisterReq, ) -from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin +from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( configure_logger, @@ -69,7 +69,7 @@ class DecodeStatus: sent_offset: int = 0 -class DetokenizerManager(MultiTokenizerMixin): +class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): """DetokenizerManager is a process that detokenizes the token ids.""" def __init__( @@ -289,11 +289,11 @@ def run_detokenizer_process( try: manager = DetokenizerManager(server_args, port_args) if server_args.tokenizer_worker_num > 1: - manager.multi_tokenizer_manager_event_loop() + manager.multi_http_worker_event_loop() else: manager.event_loop() except Exception: - manager.clear_tokenizer_mapping() + manager.socket_mapping.clear_all_sockets() traceback = get_exception_traceback() logger.error(f"DetokenizerManager hit an exception: {traceback}") parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/managers/disagg_service.py b/python/sglang/srt/managers/disagg_service.py new file mode 100644 index 000000000..df0eac48b --- /dev/null +++ b/python/sglang/srt/managers/disagg_service.py @@ -0,0 +1,46 @@ +"""Start bootstrap/kv-store-related server""" + +import os +from typing import Type + +from sglang.srt.disaggregation.base import BaseKVBootstrapServer +from sglang.srt.disaggregation.utils import ( + DisaggregationMode, + KVClassType, + TransferBackend, + get_kv_class, +) +from sglang.srt.server_args import ServerArgs + + +def start_disagg_service( + server_args: ServerArgs, +): + # Start kv boostrap server on prefill + disagg_mode = DisaggregationMode(server_args.disaggregation_mode) + transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend) + + if disagg_mode == DisaggregationMode.PREFILL: + # only start bootstrap server on prefill tm + kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class( + transfer_backend, KVClassType.BOOTSTRAP_SERVER + ) + bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class( + host=server_args.host, + port=server_args.disaggregation_bootstrap_port, + ) + is_create_store = ( + server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND + ) + if is_create_store: + try: + from mf_adapter import create_config_store + + ascend_url = os.getenv("ASCEND_MF_STORE_URL") + create_config_store(ascend_url) + except Exception as e: + error_message = f"Failed create mf store, invalid ascend_url." + error_message += f" With exception {e}" + raise error_message + + return bootstrap_server diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 94935152a..621989e03 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -13,21 +13,21 @@ # ============================================================================== """MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager.""" import asyncio -import dataclasses -import json import logging import multiprocessing as multiprocessing import os +import pickle import sys import threading from multiprocessing import shared_memory -from typing import Dict +from typing import Any, Dict import setproctitle import zmq import zmq.asyncio from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend +from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( BatchEmbeddingOut, BatchMultimodalOut, @@ -44,302 +44,296 @@ from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) -class MultiTokenizerMixin: - """Mixin class for MultiTokenizerManager and DetokenizerManager""" +class SocketMapping: + def __init__(self): + self._zmq_context = zmq.Context() + self._mapping: Dict[str, zmq.Socket] = {} - def create_sockets_mapping(self): - if not hasattr(self, "tokenizer_mapping"): - self.tokenizer_mapping = {} - # Create ZMQ context if needed - if not hasattr(self, "_zmq_context"): - self._zmq_context = zmq.Context() + def clear_all_sockets(self): + for socket in self._mapping.values(): + socket.close() + self._mapping.clear() - def init_tokenizer_mapping( - self, recv_obj: MultiTokenizerRegisterReq, worker_id: str + def register_ipc_mapping( + self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool ): - """init tokenizer mapping from register request""" - ipc_name = recv_obj.ipc_name - worker_id_int = int(worker_id) + type_str = "tokenizer" if is_tokenizer else "detokenizer" + if worker_id in self._mapping: + logger.warning( + f"{type_str} already registered with worker {worker_id}, skipping..." + ) + return + logger.info( + f"{type_str} not registered with worker {worker_id}, registering..." + ) + socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False) + self._mapping[worker_id] = socket + self._mapping[worker_id].send_pyobj(recv_obj) - if worker_id_int not in self.tokenizer_mapping: - socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False) - self.tokenizer_mapping[worker_id_int] = socket - self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj) - return True - else: - return False + def send_output(self, worker_id: str, output: Any): + if worker_id not in self._mapping: + logger.error( + f"worker ID {worker_id} not registered. Check if the server Process is alive" + ) + return + self._mapping[worker_id].send_pyobj(output) - def register_tokenizer_ipc(self, recv_obj, worker_id): - if worker_id not in self.tokenizer_mapping: - # register the worker if not already done - if isinstance(recv_obj, MultiTokenizerRegisterReq): - return self.init_tokenizer_mapping(recv_obj, worker_id) - else: - logger.error( - f"Worker {worker_id} not registered and not found in tokenizer mapping . " - "Please ensure the worker is registered correctly." - ) - return False - def _handle_output_by_index(self, output, i): - """NOTE: A maintainable method is better here.""" - if isinstance(output, BatchTokenIDOut): - new_output = BatchTokenIDOut( - rids=[output.rids[i]], - finished_reasons=( - [output.finished_reasons[i]] - if len(output.finished_reasons) > i - else None - ), - decoded_texts=( - [output.decoded_texts[i]] if len(output.decoded_texts) > i else None - ), - decode_ids=( - [output.decode_ids[i]] if len(output.decode_ids) > i else None - ), - read_offsets=( - [output.read_offsets[i]] if len(output.read_offsets) > i else None - ), - output_ids=( - [output.output_ids[i]] - if output.output_ids and len(output.output_ids) > i - else None - ), - skip_special_tokens=( - [output.skip_special_tokens[i]] - if len(output.skip_special_tokens) > i - else None - ), - spaces_between_special_tokens=( - [output.spaces_between_special_tokens[i]] - if len(output.spaces_between_special_tokens) > i - else None - ), - no_stop_trim=( - [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None - ), - prompt_tokens=( - [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None - ), - completion_tokens=( - [output.completion_tokens[i]] - if len(output.completion_tokens) > i - else None - ), - cached_tokens=( - [output.cached_tokens[i]] if len(output.cached_tokens) > i else None - ), - spec_verify_ct=( - [output.spec_verify_ct[i]] - if len(output.spec_verify_ct) > i - else None - ), - input_token_logprobs_val=( - [output.input_token_logprobs_val[i]] - if output.input_token_logprobs_val - else None - ), - input_token_logprobs_idx=( - [output.input_token_logprobs_idx[i]] - if output.input_token_logprobs_idx - else None - ), - output_token_logprobs_val=( - [output.output_token_logprobs_val[i]] - if output.output_token_logprobs_val - else None - ), - output_token_logprobs_idx=( - [output.output_token_logprobs_idx[i]] - if output.output_token_logprobs_idx - else None - ), - input_top_logprobs_val=( - [output.input_top_logprobs_val[i]] - if output.input_top_logprobs_val - else None - ), - input_top_logprobs_idx=( - [output.input_top_logprobs_idx[i]] - if output.input_top_logprobs_idx - else None - ), - output_top_logprobs_val=( - [output.output_top_logprobs_val[i]] - if output.output_top_logprobs_val - else None - ), - output_top_logprobs_idx=( - [output.output_top_logprobs_idx[i]] - if output.output_top_logprobs_idx - else None - ), - input_token_ids_logprobs_val=( - [output.input_token_ids_logprobs_val[i]] - if output.input_token_ids_logprobs_val - else None - ), - input_token_ids_logprobs_idx=( - [output.input_token_ids_logprobs_idx[i]] - if output.input_token_ids_logprobs_idx - else None - ), - output_token_ids_logprobs_val=( - [output.output_token_ids_logprobs_val[i]] - if output.output_token_ids_logprobs_val - else None - ), - output_token_ids_logprobs_idx=( - [output.output_token_ids_logprobs_idx[i]] - if output.output_token_ids_logprobs_idx - else None - ), - output_hidden_states=( - [output.output_hidden_states[i]] - if output.output_hidden_states - else None - ), - ) - elif isinstance(output, BatchEmbeddingOut): - new_output = BatchEmbeddingOut( - rids=[output.rids[i]], - finished_reasons=( - [output.finished_reasons[i]] - if len(output.finished_reasons) > i - else None - ), - embeddings=( - [output.embeddings[i]] if len(output.embeddings) > i else None - ), - prompt_tokens=( - [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None - ), - cached_tokens=( - [output.cached_tokens[i]] if len(output.cached_tokens) > i else None - ), - ) - elif isinstance(output, BatchStrOut): - new_output = BatchStrOut( - rids=[output.rids[i]], - finished_reasons=( - [output.finished_reasons[i]] - if len(output.finished_reasons) > i - else None - ), - output_strs=( - [output.output_strs[i]] if len(output.output_strs) > i else None - ), - output_ids=( - [output.output_ids[i]] - if output.output_ids and len(output.output_ids) > i - else None - ), - prompt_tokens=( - [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None - ), - completion_tokens=( - [output.completion_tokens[i]] - if len(output.completion_tokens) > i - else None - ), - cached_tokens=( - [output.cached_tokens[i]] if len(output.cached_tokens) > i else None - ), - spec_verify_ct=( - [output.spec_verify_ct[i]] - if len(output.spec_verify_ct) > i - else None - ), - input_token_logprobs_val=( - [output.input_token_logprobs_val[i]] - if output.input_token_logprobs_val - else None - ), - input_token_logprobs_idx=( - [output.input_token_logprobs_idx[i]] - if output.input_token_logprobs_idx - else None - ), - output_token_logprobs_val=( - [output.output_token_logprobs_val[i]] - if output.output_token_logprobs_val - else None - ), - output_token_logprobs_idx=( - [output.output_token_logprobs_idx[i]] - if output.output_token_logprobs_idx - else None - ), - input_top_logprobs_val=( - [output.input_top_logprobs_val[i]] - if output.input_top_logprobs_val - else None - ), - input_top_logprobs_idx=( - [output.input_top_logprobs_idx[i]] - if output.input_top_logprobs_idx - else None - ), - output_top_logprobs_val=( - [output.output_top_logprobs_val[i]] - if output.output_top_logprobs_val - else None - ), - output_top_logprobs_idx=( - [output.output_top_logprobs_idx[i]] - if output.output_top_logprobs_idx - else None - ), - input_token_ids_logprobs_val=( - [output.input_token_ids_logprobs_val[i]] - if output.input_token_ids_logprobs_val - else None - ), - input_token_ids_logprobs_idx=( - [output.input_token_ids_logprobs_idx[i]] - if output.input_token_ids_logprobs_idx - else None - ), - output_token_ids_logprobs_val=( - [output.output_token_ids_logprobs_val[i]] - if output.output_token_ids_logprobs_val - else None - ), - output_token_ids_logprobs_idx=( - [output.output_token_ids_logprobs_idx[i]] - if output.output_token_ids_logprobs_idx - else None - ), - output_hidden_states=( - [output.output_hidden_states[i]] - if output.output_hidden_states - else None - ), - ) - elif isinstance(output, BatchMultimodalOut): - new_output = BatchMultimodalOut( - rids=[output.rids[i]], - finished_reasons=( - [output.finished_reasons[i]] - if len(output.finished_reasons) > i - else None - ), - outputs=([output.outputs[i]] if len(output.outputs) > i else None), - prompt_tokens=( - [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None - ), - completion_tokens=( - [output.completion_tokens[i]] - if len(output.completion_tokens) > i - else None - ), - cached_tokens=( - [output.cached_tokens[i]] if len(output.cached_tokens) > i else None - ), - ) - else: - new_output = output - return new_output +def _handle_output_by_index(output, i): + """NOTE: A maintainable method is better here.""" + if isinstance(output, BatchTokenIDOut): + new_output = BatchTokenIDOut( + rids=[output.rids[i]], + finished_reasons=( + [output.finished_reasons[i]] + if len(output.finished_reasons) > i + else None + ), + decoded_texts=( + [output.decoded_texts[i]] if len(output.decoded_texts) > i else None + ), + decode_ids=([output.decode_ids[i]] if len(output.decode_ids) > i else None), + read_offsets=( + [output.read_offsets[i]] if len(output.read_offsets) > i else None + ), + output_ids=( + [output.output_ids[i]] + if output.output_ids and len(output.output_ids) > i + else None + ), + skip_special_tokens=( + [output.skip_special_tokens[i]] + if len(output.skip_special_tokens) > i + else None + ), + spaces_between_special_tokens=( + [output.spaces_between_special_tokens[i]] + if len(output.spaces_between_special_tokens) > i + else None + ), + no_stop_trim=( + [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None + ), + prompt_tokens=( + [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None + ), + completion_tokens=( + [output.completion_tokens[i]] + if len(output.completion_tokens) > i + else None + ), + cached_tokens=( + [output.cached_tokens[i]] if len(output.cached_tokens) > i else None + ), + spec_verify_ct=( + [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None + ), + input_token_logprobs_val=( + [output.input_token_logprobs_val[i]] + if output.input_token_logprobs_val + else None + ), + input_token_logprobs_idx=( + [output.input_token_logprobs_idx[i]] + if output.input_token_logprobs_idx + else None + ), + output_token_logprobs_val=( + [output.output_token_logprobs_val[i]] + if output.output_token_logprobs_val + else None + ), + output_token_logprobs_idx=( + [output.output_token_logprobs_idx[i]] + if output.output_token_logprobs_idx + else None + ), + input_top_logprobs_val=( + [output.input_top_logprobs_val[i]] + if output.input_top_logprobs_val + else None + ), + input_top_logprobs_idx=( + [output.input_top_logprobs_idx[i]] + if output.input_top_logprobs_idx + else None + ), + output_top_logprobs_val=( + [output.output_top_logprobs_val[i]] + if output.output_top_logprobs_val + else None + ), + output_top_logprobs_idx=( + [output.output_top_logprobs_idx[i]] + if output.output_top_logprobs_idx + else None + ), + input_token_ids_logprobs_val=( + [output.input_token_ids_logprobs_val[i]] + if output.input_token_ids_logprobs_val + else None + ), + input_token_ids_logprobs_idx=( + [output.input_token_ids_logprobs_idx[i]] + if output.input_token_ids_logprobs_idx + else None + ), + output_token_ids_logprobs_val=( + [output.output_token_ids_logprobs_val[i]] + if output.output_token_ids_logprobs_val + else None + ), + output_token_ids_logprobs_idx=( + [output.output_token_ids_logprobs_idx[i]] + if output.output_token_ids_logprobs_idx + else None + ), + output_hidden_states=( + [output.output_hidden_states[i]] + if output.output_hidden_states + else None + ), + ) + elif isinstance(output, BatchEmbeddingOut): + new_output = BatchEmbeddingOut( + rids=[output.rids[i]], + finished_reasons=( + [output.finished_reasons[i]] + if len(output.finished_reasons) > i + else None + ), + embeddings=([output.embeddings[i]] if len(output.embeddings) > i else None), + prompt_tokens=( + [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None + ), + cached_tokens=( + [output.cached_tokens[i]] if len(output.cached_tokens) > i else None + ), + ) + elif isinstance(output, BatchStrOut): + new_output = BatchStrOut( + rids=[output.rids[i]], + finished_reasons=( + [output.finished_reasons[i]] + if len(output.finished_reasons) > i + else None + ), + output_strs=( + [output.output_strs[i]] if len(output.output_strs) > i else None + ), + output_ids=( + [output.output_ids[i]] + if output.output_ids and len(output.output_ids) > i + else None + ), + prompt_tokens=( + [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None + ), + completion_tokens=( + [output.completion_tokens[i]] + if len(output.completion_tokens) > i + else None + ), + cached_tokens=( + [output.cached_tokens[i]] if len(output.cached_tokens) > i else None + ), + spec_verify_ct=( + [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None + ), + input_token_logprobs_val=( + [output.input_token_logprobs_val[i]] + if output.input_token_logprobs_val + else None + ), + input_token_logprobs_idx=( + [output.input_token_logprobs_idx[i]] + if output.input_token_logprobs_idx + else None + ), + output_token_logprobs_val=( + [output.output_token_logprobs_val[i]] + if output.output_token_logprobs_val + else None + ), + output_token_logprobs_idx=( + [output.output_token_logprobs_idx[i]] + if output.output_token_logprobs_idx + else None + ), + input_top_logprobs_val=( + [output.input_top_logprobs_val[i]] + if output.input_top_logprobs_val + else None + ), + input_top_logprobs_idx=( + [output.input_top_logprobs_idx[i]] + if output.input_top_logprobs_idx + else None + ), + output_top_logprobs_val=( + [output.output_top_logprobs_val[i]] + if output.output_top_logprobs_val + else None + ), + output_top_logprobs_idx=( + [output.output_top_logprobs_idx[i]] + if output.output_top_logprobs_idx + else None + ), + input_token_ids_logprobs_val=( + [output.input_token_ids_logprobs_val[i]] + if output.input_token_ids_logprobs_val + else None + ), + input_token_ids_logprobs_idx=( + [output.input_token_ids_logprobs_idx[i]] + if output.input_token_ids_logprobs_idx + else None + ), + output_token_ids_logprobs_val=( + [output.output_token_ids_logprobs_val[i]] + if output.output_token_ids_logprobs_val + else None + ), + output_token_ids_logprobs_idx=( + [output.output_token_ids_logprobs_idx[i]] + if output.output_token_ids_logprobs_idx + else None + ), + output_hidden_states=( + [output.output_hidden_states[i]] + if output.output_hidden_states + else None + ), + ) + elif isinstance(output, BatchMultimodalOut): + new_output = BatchMultimodalOut( + rids=[output.rids[i]], + finished_reasons=( + [output.finished_reasons[i]] + if len(output.finished_reasons) > i + else None + ), + outputs=([output.outputs[i]] if len(output.outputs) > i else None), + prompt_tokens=( + [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None + ), + completion_tokens=( + [output.completion_tokens[i]] + if len(output.completion_tokens) > i + else None + ), + cached_tokens=( + [output.cached_tokens[i]] if len(output.cached_tokens) > i else None + ), + ) + else: + new_output = output + return new_output + + +class MultiHttpWorkerDetokenizerMixin: + """Mixin class for MultiTokenizerManager and DetokenizerManager""" def get_worker_ids_from_req_rids(self, rids): if isinstance(rids, list): @@ -350,9 +344,9 @@ class MultiTokenizerMixin: worker_ids = [] return worker_ids - def multi_tokenizer_manager_event_loop(self): - """The event loop that handles requests, for multi tokenizer manager mode only""" - self.create_sockets_mapping() + def multi_http_worker_event_loop(self): + """The event loop that handles requests, for multi multi-http-worker mode""" + self.socket_mapping = SocketMapping() while True: recv_obj = self.recv_from_scheduler.recv_pyobj() output = self._request_dispatcher(recv_obj) @@ -369,31 +363,15 @@ class MultiTokenizerMixin: # Send data using the corresponding socket for i, worker_id in enumerate(worker_ids): if isinstance(recv_obj, MultiTokenizerRegisterReq): - if self.register_tokenizer_ipc(recv_obj, worker_id): - logger.info( - f"DetokenizerManager Created ZMQ socket for worker {worker_id}" - ) - continue + self.socket_mapping.register_ipc_mapping( + recv_obj, worker_id, is_tokenizer=False + ) else: - if worker_id not in self.tokenizer_mapping: - logger.error( - f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive" - ) - continue - new_output = self._handle_output_by_index(output, i) - self.tokenizer_mapping[worker_id].send_pyobj(new_output) - - def clear_tokenizer_mapping(self): - if hasattr(self, "tokenizer_mapping"): - for socket in self.tokenizer_mapping.values(): - try: - socket.close() - except Exception as e: - logger.warning(f"Failed to close socket: {e}") - self.tokenizer_mapping.clear() + new_output = _handle_output_by_index(output, i) + self.socket_mapping.send_output(worker_id, new_output) -class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin): +class MultiTokenizerRouter: """A router to receive requests from MultiTokenizerManager""" def __init__( @@ -422,7 +400,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin): self._handle_task = asyncio.run_coroutine_threadsafe( print_exception_wrapper(self.handle_loop), self._loop ) - self.init_disaggregation() + self.disaggregation_bootstrap_server = start_disagg_service(self.server_args) def _run_loop(self): self._loop.run_forever() @@ -434,7 +412,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin): async def handle_loop(self): # special reqs will recv from scheduler, need to route to right worker - self.create_sockets_mapping() + self.socket_mapping = SocketMapping() while True: recv_obj = await self.recv_from_detokenizer.recv_pyobj() await self._distribute_result_to_workers(recv_obj) @@ -454,22 +432,15 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin): # Distribute result to each worker for i, worker_id in enumerate(worker_ids): if isinstance(recv_obj, MultiTokenizerRegisterReq): - if self.register_tokenizer_ipc(recv_obj, worker_id): - logger.info( - f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}" - ) - continue + self.socket_mapping.register_ipc_mapping( + recv_obj, worker_id, is_tokenizer=True + ) else: - if worker_id not in self.tokenizer_mapping: - logger.error( - f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive" - ) - continue - new_recv_obj = self._handle_output_by_index(recv_obj, i) - self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj) + new_recv_obj = _handle_output_by_index(recv_obj, i) + self.socket_mapping.send_output(worker_id, new_recv_obj) -class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin): +class MultiTokenizerManager(TokenizerManager): """Multi Process Tokenizer Manager that tokenizes the text.""" def __init__( @@ -535,42 +506,14 @@ async def print_exception_wrapper(func): sys.exit(1) -def serialize_port_args(port_args: PortArgs) -> dict: - """Serialize PortArgs into a shareable dictionary""" - return { - "tokenizer_ipc_name": port_args.tokenizer_ipc_name, - "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name, - "detokenizer_ipc_name": port_args.detokenizer_ipc_name, - "nccl_port": port_args.nccl_port, - "rpc_ipc_name": port_args.rpc_ipc_name, - "metrics_ipc_name": port_args.metrics_ipc_name, - "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name, - } +def get_main_process_id() -> int: + """Get the main process ID""" + return multiprocessing.current_process()._parent_pid -def deserialize_data(port_args: dict, server_args: dict): - """Deserialize data from shared dictionaries""" - return PortArgs(**port_args), ServerArgs(**server_args) - - -def serialize_server_args(server_args: ServerArgs) -> dict: - """Serialize ServerArgs into a shareable dictionary""" - return dataclasses.asdict(server_args) - - -def serialize_scheduler_info(scheduler_info: Dict) -> dict: - """Serialize scheduler_info into a shareable dictionary""" - return scheduler_info - - -def deserialize_scheduler_info(data: dict) -> Dict: - """Deserialize scheduler_info from a shared dictionary""" - return data - - -def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory: +def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory: """Write data to shared memory""" - serialized = json.dumps(data).encode("utf-8") + serialized = pickle.dumps(obj) size = len(serialized) try: # Try to open existing shared memory @@ -588,22 +531,17 @@ def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory: return shm -def read_from_shared_memory(name: str) -> dict: +def read_from_shared_memory(name: str) -> Any: """Read data from shared memory""" try: shm = shared_memory.SharedMemory(name=name) - data = json.loads(bytes(shm.buf).decode("utf-8")) + data = pickle.loads(bytes(shm.buf)) shm.close() return data except FileNotFoundError: raise FileNotFoundError(f"Shared memory {name} not found") -def get_main_process_id() -> int: - """Get the main process ID""" - return multiprocessing.current_process()._parent_pid - - def write_data_for_multi_tokenizer( port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict ): @@ -612,22 +550,8 @@ def write_data_for_multi_tokenizer( main_pid = get_main_process_id() current_pid = os.getpid() logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}") + args = (port_args, server_args, scheduler_info) + args_shm = write_to_shared_memory(args, f"multi_tokenizer_args_{current_pid}") + args_shm.close() - # Write port_args to shared memory - port_args_shm = write_to_shared_memory( - serialize_port_args(port_args), f"port_args_{current_pid}" - ) - # Write server_args to shared memory - server_args_shm = write_to_shared_memory( - serialize_server_args(server_args), f"server_args_{current_pid}" - ) - # Write scheduler_info to shared memory - scheduler_info_shm = write_to_shared_memory( - serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}" - ) - - port_args_shm.close() - server_args_shm.close() - scheduler_info_shm.close() - - return port_args_shm, server_args_shm, scheduler_info_shm + return args_shm diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d23d1a628..c00235587 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -54,19 +54,14 @@ from fastapi import BackgroundTasks from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.disaggregation.base import BaseKVBootstrapServer -from sglang.srt.disaggregation.utils import ( - DisaggregationMode, - KVClassType, - TransferBackend, - get_kv_class, -) +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.hf_transformers_utils import ( get_processor, get_tokenizer, get_tokenizer_from_processor, ) from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry +from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -321,8 +316,10 @@ class TokenizerManager: # LoRA updates and inference to overlap. self.lora_update_lock = asyncio.Lock() - # For PD disaggregtion - self.init_disaggregation() + self.disaggregation_mode = DisaggregationMode( + self.server_args.disaggregation_mode + ) + self.bootstrap_server = start_disagg_service(self.server_args) # For load balancing self.current_load = 0 @@ -471,38 +468,6 @@ class TokenizerManager: ] ) - def init_disaggregation(self): - self.disaggregation_mode = DisaggregationMode( - self.server_args.disaggregation_mode - ) - self.disaggregation_transfer_backend = TransferBackend( - self.server_args.disaggregation_transfer_backend - ) - # Start kv boostrap server on prefill - if self.disaggregation_mode == DisaggregationMode.PREFILL: - # only start bootstrap server on prefill tm - kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class( - self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER - ) - self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class( - host=self.server_args.host, - port=self.server_args.disaggregation_bootstrap_port, - ) - is_create_store = ( - self.server_args.node_rank == 0 - and self.server_args.disaggregation_transfer_backend == "ascend" - ) - if is_create_store: - try: - from mf_adapter import create_config_store - - ascend_url = os.getenv("ASCEND_MF_STORE_URL") - create_config_store(ascend_url) - except Exception as e: - error_message = f"Failed create mf store, invalid ascend_url." - error_message += f" With exception {e}" - raise error_message - async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput],