From 260fe755b648f636e0256aa327c959219276c6ad Mon Sep 17 00:00:00 2001 From: Zhengke Zhou Date: Tue, 21 Oct 2025 16:33:29 +0800 Subject: [PATCH] Simplify multi-tokenizer (#11295) Signed-off-by: zhengkezhou1 Co-authored-by: Liangsheng Yin --- python/sglang/srt/entrypoints/http_server.py | 11 +- .../srt/managers/detokenizer_manager.py | 4 +- python/sglang/srt/managers/io_struct.py | 14 +- .../srt/managers/multi_tokenizer_mixin.py | 130 ++++++--------- python/sglang/srt/managers/schedule_batch.py | 4 + python/sglang/srt/managers/scheduler.py | 152 ++++++++++-------- .../scheduler_output_processor_mixin.py | 10 +- .../managers/tokenizer_communicator_mixin.py | 10 -- .../sglang/srt/managers/tokenizer_manager.py | 39 +++-- python/sglang/srt/utils/common.py | 4 - 10 files changed, 174 insertions(+), 204 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 982e6467e..129793252 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -149,15 +149,14 @@ def set_global_state(global_state: _GlobalState): async def init_multi_tokenizer() -> ServerArgs: """Read args information from shm and init tokenizer manager for current process""" - pid = os.getpid() - main_pid = get_main_process_id() - logger.info(f"current worker_id: {pid}, main processID: {main_pid}") # Read configuration from shared memory + main_pid = get_main_process_id() port_args, server_args, scheduler_info = read_from_shared_memory( f"multi_tokenizer_args_{main_pid}" ) server_args: ServerArgs + port_args: PortArgs # API key authentication is not supported in multi-tokenizer mode assert ( @@ -167,6 +166,10 @@ async def init_multi_tokenizer() -> ServerArgs: port_args.tokenizer_ipc_name = ( f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" ) + logger.info( + f"Start multi-tokenizer worker process {os.getpid()}, " + f"ipc_name={port_args.tokenizer_ipc_name}" + ) # Launch multi-tokenizer manager process tokenizer_manager = TokenizerWorker(server_args, port_args) @@ -177,8 +180,6 @@ async def init_multi_tokenizer() -> ServerArgs: chat_template=server_args.chat_template, completion_template=server_args.completion_template, ) - # Register this tokenizer with the main tokenizer manager - await tokenizer_manager.register_to_main_tokenizer_manager() tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] set_global_state( diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index ab2777ff3..b3c6df7d5 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -31,7 +31,6 @@ from sglang.srt.managers.io_struct import ( BatchStrOutput, BatchTokenIDOutput, FreezeGCReq, - MultiTokenizerRegisterReq, ) from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin from sglang.srt.server_args import PortArgs, ServerArgs @@ -104,7 +103,6 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): (BatchEmbeddingOutput, self.handle_batch_embedding_out), (BatchTokenIDOutput, self.handle_batch_token_id_out), (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req), - (MultiTokenizerRegisterReq, lambda x: x), (FreezeGCReq, self.handle_freeze_gc_req), ] ) @@ -227,6 +225,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): return BatchStrOutput( rids=recv_obj.rids, + http_worker_ipcs=recv_obj.http_worker_ipcs, finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, output_ids=recv_obj.decode_ids, @@ -258,6 +257,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): outputs = self.tokenizer.detokenize(recv_obj) return BatchMultimodalOutput( rids=recv_obj.rids, + http_worker_ipcs=recv_obj.http_worker_ipcs, finished_reasons=recv_obj.finished_reasons, outputs=outputs, prompt_tokens=recv_obj.prompt_tokens, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 4aa07411b..cd67d4dc3 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -39,6 +39,7 @@ else: @dataclass class BaseReq(ABC): rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True) + http_worker_ipc: Optional[str] = field(default=None, kw_only=True) def regenerate_rid(self): """Generate a new request ID and return it.""" @@ -52,6 +53,7 @@ class BaseReq(ABC): @dataclass class BaseBatchReq(ABC): rids: Optional[List[str]] = field(default=None, kw_only=True) + http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True) def regenerate_rids(self): """Generate new request IDs and return them.""" @@ -1407,18 +1409,6 @@ class LoRAUpdateOutput(BaseReq): LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput -@dataclass -class MultiTokenizerRegisterReq(BaseBatchReq): - ipc_name: Optional[str] = None - - -@dataclass -class MultiTokenizerWrapper: - # FIXME(lsyin): remove this - worker_id: int - obj: Optional[Any] = None - - class BlockReqType(Enum): BLOCK = 1 UNBLOCK = 2 diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 4a3eb1323..8009255a2 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +23,7 @@ import sys import threading from functools import partialmethod from multiprocessing import shared_memory -from typing import Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Union import setproctitle import zmq @@ -30,12 +32,12 @@ import zmq.asyncio from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( + BaseBatchReq, + BaseReq, BatchEmbeddingOutput, BatchMultimodalOutput, BatchStrOutput, BatchTokenIDOutput, - MultiTokenizerRegisterReq, - MultiTokenizerWrapper, ) from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -43,6 +45,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import get_zmq_socket, kill_process_tree from sglang.utils import get_exception_traceback +if TYPE_CHECKING: + from sglang.srt.managers.detokenizer_manager import DetokenizerManager + logger = logging.getLogger(__name__) @@ -56,29 +61,24 @@ class SocketMapping: socket.close() self._mapping.clear() - def register_ipc_mapping( - self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool - ): + def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool): type_str = "tokenizer" if is_tokenizer else "detokenizer" - if worker_id in self._mapping: - logger.warning( - f"{type_str} already registered with worker {worker_id}, skipping..." - ) + if ipc_name in self._mapping: + logger.warning(f"{type_str} already registered {ipc_name=}, skipping...") return - logger.info( - f"{type_str} not registered with worker {worker_id}, registering..." - ) - socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False) - self._mapping[worker_id] = socket - self._mapping[worker_id].send_pyobj(recv_obj) + logger.info(f"Registering {type_str} {ipc_name=} in SocketMapping...") + socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False) + self._mapping[ipc_name] = socket - def send_output(self, worker_id: str, output: Any): - if worker_id not in self._mapping: - logger.error( - f"worker ID {worker_id} not registered. Check if the server Process is alive" - ) + def send_output(self, ipc_name: str, output: Any): + if ipc_name is None: + # Some unhandled cases + logger.warning(f"IPC name is None, output type={type(output)}, skipping...") return - self._mapping[worker_id].send_pyobj(output) + + if ipc_name not in self._mapping: + self._register_ipc_mapping(ipc_name, is_tokenizer=False) + self._mapping[ipc_name].send_pyobj(output) def _handle_output_by_index(output, i): @@ -362,20 +362,11 @@ def _handle_output_by_index(output, i): class MultiHttpWorkerDetokenizerMixin: """Mixin class for DetokenizerManager""" - def get_worker_ids_from_req_rids(self, rids): - if isinstance(rids, list): - worker_ids = [int(rid.split("_")[0]) for rid in rids] - elif isinstance(rids, str): - worker_ids = [int(rids.split("_")[0])] - else: - worker_ids = [] - return worker_ids - - def maybe_clear_socket_mapping(self): + def maybe_clear_socket_mapping(self: DetokenizerManager): if hasattr(self, "socket_mapping"): self.socket_mapping.clear_all_sockets() - def multi_http_worker_event_loop(self): + def multi_http_worker_event_loop(self: DetokenizerManager): """The event loop that handles requests, for multi multi-http-worker mode""" self.socket_mapping = SocketMapping() while True: @@ -383,23 +374,15 @@ class MultiHttpWorkerDetokenizerMixin: output = self._request_dispatcher(recv_obj) if output is None: continue - # Extract worker_id from rid - if isinstance(recv_obj.rids, list): - worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids) - else: - raise RuntimeError( - f"for tokenizer_worker_num > 1, recv_obj.rids must be a list" - ) + + assert isinstance( + recv_obj, BaseBatchReq + ), "for multi-http-worker, recv_obj must be BaseBatchReq" # Send data using the corresponding socket - for i, worker_id in enumerate(worker_ids): - if isinstance(recv_obj, MultiTokenizerRegisterReq): - self.socket_mapping.register_ipc_mapping( - recv_obj, worker_id, is_tokenizer=False - ) - else: - new_output = _handle_output_by_index(output, i) - self.socket_mapping.send_output(worker_id, new_output) + for i, ipc_name in enumerate(recv_obj.http_worker_ipcs): + new_output = _handle_output_by_index(output, i) + self.socket_mapping.send_output(ipc_name, new_output) class MultiTokenizerRouter: @@ -449,26 +432,17 @@ class MultiTokenizerRouter: await self._distribute_result_to_workers(recv_obj) async def _distribute_result_to_workers(self, recv_obj): - """Distribute result to corresponding workers based on rid""" - if isinstance(recv_obj, MultiTokenizerWrapper): - worker_ids = [recv_obj.worker_id] - recv_obj = recv_obj.obj - else: - worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids) - - if len(worker_ids) == 0: - logger.error(f"Cannot find worker_id from rids {recv_obj.rids}") - return - # Distribute result to each worker - for i, worker_id in enumerate(worker_ids): - if isinstance(recv_obj, MultiTokenizerRegisterReq): - self.socket_mapping.register_ipc_mapping( - recv_obj, worker_id, is_tokenizer=True - ) - else: - new_recv_obj = _handle_output_by_index(recv_obj, i) - self.socket_mapping.send_output(worker_id, new_recv_obj) + if isinstance(recv_obj, BaseReq): + ipc_names = [recv_obj.http_worker_ipc] + elif isinstance(recv_obj, BaseBatchReq): + ipc_names = recv_obj.http_worker_ipcs + else: + raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}") + + for i, ipc_name in enumerate(ipc_names): + new_recv_obj = _handle_output_by_index(recv_obj, i) + self.socket_mapping.send_output(ipc_name, new_recv_obj) class TokenizerWorker(TokenizerManager): @@ -500,21 +474,15 @@ class TokenizerWorker(TokenizerManager): self.register_multi_tokenizer_communicator = _Communicator( self.send_to_scheduler, 2 ) - self._result_dispatcher._mapping.append( - ( - MultiTokenizerRegisterReq, - self.register_multi_tokenizer_communicator.handle_recv, - ) - ) - async def register_to_main_tokenizer_manager(self): - """Register this worker to the main TokenizerManager""" - # create a handle loop to receive messages from the main TokenizerManager - self.auto_create_handle_loop() - req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"]) - req.ipc_name = self.tokenizer_ipc_name - _Communicator.enable_multi_tokenizer = True - await self.register_multi_tokenizer_communicator(req) + def _attach_multi_http_worker_info(self, req: Union[BaseReq, BaseBatchReq]): + + if isinstance(req, BaseReq): + req.http_worker_ipc = self.tokenizer_ipc_name + elif isinstance(req, BaseBatchReq): + req.http_worker_ipcs = [self.tokenizer_ipc_name] * len(req.rids) + else: + raise ValueError(f"Unknown req type: {type(req)}") async def print_exception_wrapper(func): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 45c104c25..682213e68 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -438,6 +438,7 @@ class Req: priority: Optional[int] = None, metrics_collector: Optional[SchedulerMetricsCollector] = None, extra_key: Optional[str] = None, + http_worker_ipc: Optional[str] = None, ): # Input and output info self.rid = rid @@ -461,6 +462,9 @@ class Req: # The length of KV that have been removed in local attention chunked prefill self.evicted_seqlen_local = 0 + # For multi-http worker + self.http_worker_ipc = http_worker_ipc + # Sampling info if isinstance(sampling_params.custom_params, dict): sampling_params = copy.copy(sampling_params) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a64439b46..3a9f92ffc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -24,7 +24,6 @@ from collections import deque from concurrent import futures from dataclasses import dataclass from http import HTTPStatus -from types import SimpleNamespace from typing import Deque, Dict, List, Optional, Tuple, Union import psutil @@ -66,6 +65,8 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.managers.io_struct import ( AbortReq, + BaseBatchReq, + BaseReq, BatchTokenizedEmbeddingReqInput, BatchTokenizedGenerateReqInput, ClearHiCacheReqInput, @@ -89,8 +90,6 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, LoadLoRAAdapterReqOutput, - MultiTokenizerRegisterReq, - MultiTokenizerWrapper, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, @@ -277,47 +276,7 @@ class Scheduler( self.model_config = ModelConfig.from_server_args(server_args) # Init inter-process communication - context = zmq.Context(2) - self.idle_sleeper = None - if self.pp_rank == 0 and self.attn_tp_rank == 0: - self.recv_from_tokenizer = get_zmq_socket( - context, zmq.PULL, port_args.scheduler_input_ipc_name, False - ) - self.recv_from_rpc = get_zmq_socket( - context, zmq.DEALER, port_args.rpc_ipc_name, False - ) - - self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name, False - ) - if server_args.skip_tokenizer_init: - # Directly send to the TokenizerManager - self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name, False - ) - else: - # Send to the DetokenizerManager - self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.detokenizer_ipc_name, False - ) - - if self.server_args.sleep_on_idle: - self.idle_sleeper = IdleSleeper( - [ - self.recv_from_tokenizer, - self.recv_from_rpc, - ] - ) - else: - self.recv_from_tokenizer = None - self.recv_from_rpc = None - self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None) - self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) - - if self.current_scheduler_metrics_enabled(): - self.send_metrics_from_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.metrics_ipc_name, False - ) + self.init_sockets(server_args, port_args) # Init tokenizer self.init_tokenizer() @@ -578,7 +537,6 @@ class Scheduler( (ExpertDistributionReq, self.expert_distribution_handle), (LoadLoRAAdapterReqInput, self.load_lora_adapter), (UnloadLoRAAdapterReqInput, self.unload_lora_adapter), - (MultiTokenizerRegisterReq, self.register_multi_tokenizer), (GetLoadReqInput, self.get_load), ] ) @@ -634,6 +592,75 @@ class Scheduler( else: self.draft_worker = None + def init_sockets(self, server_args: ServerArgs, port_args: PortArgs): + context = zmq.Context(2) + self.idle_sleeper = None + + class SenderWrapper: + def __init__(self, socket: zmq.Socket): + self.socket = socket + + def send_output( + self, + output: Union[BaseReq, BaseBatchReq], + recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None, + ): + if self.socket is None: + return + + if ( + isinstance(recv_obj, BaseReq) + and recv_obj.http_worker_ipc is not None + and output.http_worker_ipc is None + ): + # handle communicator reqs for multi-http worker case + output.http_worker_ipc = recv_obj.http_worker_ipc + + self.socket.send_pyobj(output) + + if self.pp_rank == 0 and self.attn_tp_rank == 0: + self.recv_from_tokenizer = get_zmq_socket( + context, zmq.PULL, port_args.scheduler_input_ipc_name, False + ) + self.recv_from_rpc = get_zmq_socket( + context, zmq.DEALER, port_args.rpc_ipc_name, False + ) + + send_to_tokenizer = get_zmq_socket( + context, zmq.PUSH, port_args.tokenizer_ipc_name, False + ) + if server_args.skip_tokenizer_init: + # Directly send to the TokenizerManager + send_to_detokenizer = get_zmq_socket( + context, zmq.PUSH, port_args.tokenizer_ipc_name, False + ) + else: + # Send to the DetokenizerManager + send_to_detokenizer = get_zmq_socket( + context, zmq.PUSH, port_args.detokenizer_ipc_name, False + ) + + self.send_to_tokenizer = SenderWrapper(send_to_tokenizer) + self.send_to_detokenizer = SenderWrapper(send_to_detokenizer) + + if self.server_args.sleep_on_idle: + self.idle_sleeper = IdleSleeper( + [ + self.recv_from_tokenizer, + self.recv_from_rpc, + ] + ) + else: + self.recv_from_tokenizer = None + self.recv_from_rpc = None + self.send_to_tokenizer = SenderWrapper(None) + self.send_to_detokenizer = SenderWrapper(None) + + if self.current_scheduler_metrics_enabled(): + self.send_metrics_from_scheduler = get_zmq_socket( + context, zmq.PUSH, port_args.metrics_ipc_name, False + ) + def init_deterministic_inference_config(self): """Initialize deterministic inference configuration for different attention backends.""" if not self.server_args.enable_deterministic_inference: @@ -1107,23 +1134,13 @@ class Scheduler( self.return_health_check_ct += 1 continue - # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request. - if isinstance(recv_req, MultiTokenizerWrapper): - worker_id = recv_req.worker_id - recv_req = recv_req.obj - output = self._request_dispatcher(recv_req) - if output is not None: - output = MultiTokenizerWrapper(worker_id, output) - self.send_to_tokenizer.send_pyobj(output) - continue - output = self._request_dispatcher(recv_req) if output is not None: if isinstance(output, RpcReqOutput): if self.recv_from_rpc is not None: self.recv_from_rpc.send_pyobj(output) else: - self.send_to_tokenizer.send_pyobj(output) + self.send_to_tokenizer.send_output(output, recv_req) def init_req_max_new_tokens(self, req): req.sampling_params.max_new_tokens = min( @@ -1179,6 +1196,7 @@ class Scheduler( metrics_collector=( self.metrics_collector if self.enable_metrics else None ), + http_worker_ipc=recv_req.http_worker_ipc, ) req.tokenizer = self.tokenizer @@ -1382,7 +1400,7 @@ class Scheduler( }, rid=req.rid, ) - self.send_to_tokenizer.send_pyobj(abort_req) + self.send_to_tokenizer.send_output(abort_req, req) def _abort_on_queued_limit(self, recv_req: Req) -> bool: """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted.""" @@ -1414,7 +1432,7 @@ class Scheduler( req_to_abort = candidate_req message = "The request is aborted by a higher priority request." - self.send_to_tokenizer.send_pyobj( + self.send_to_tokenizer.send_output( AbortReq( finished_reason={ "type": "abort", @@ -1422,7 +1440,8 @@ class Scheduler( "message": message, }, rid=req_to_abort.rid, - ) + ), + req_to_abort, ) return req_to_abort.rid == recv_req.rid @@ -1437,6 +1456,7 @@ class Scheduler( recv_req.sampling_params, token_type_ids=recv_req.token_type_ids, priority=recv_req.priority, + http_worker_ipc=recv_req.http_worker_ipc, ) req.tokenizer = self.tokenizer @@ -1953,8 +1973,8 @@ class Scheduler( self.num_retracted_reqs = len(retracted_reqs) self.new_token_ratio = new_token_ratio for req in reqs_to_abort: - self.send_to_tokenizer.send_pyobj( - AbortReq(abort_reason=req.to_abort_message, rid=req.rid) + self.send_to_tokenizer.send_output( + AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req ) logger.info( @@ -2138,7 +2158,7 @@ class Scheduler( # This is used to prevent the health check signal being blocked by long context prefill. # However, one minor issue is that this code path does not check the status of detokenizer manager. self.return_health_check_ct -= 1 - self.send_to_tokenizer.send_pyobj(HealthCheckOutput()) + self.send_to_tokenizer.send_output(HealthCheckOutput()) def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch): return self.prepare_mlp_sync_batch_raw( @@ -2585,7 +2605,7 @@ class Scheduler( if self.enable_hicache_storage: # to release prefetch events associated with the request self.tree_cache.release_aborted_request(req.rid) - self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid)) + self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req) # For disaggregation decode mode, the request in the waiting queue has KV cache allocated. if self.disaggregation_mode == DisaggregationMode.DECODE: self.tree_cache.cache_finished_req(req) @@ -2669,10 +2689,6 @@ class Scheduler( result = self.tp_worker.unload_lora_adapter(recv_req) return result - def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq): - self.send_to_detokenizer.send_pyobj(recv_req) - return recv_req - def init_weights_send_group_for_remote_instance( self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput ): @@ -2751,7 +2767,7 @@ class Scheduler( def handle_freeze_gc(self, recv_req: FreezeGCReq): """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer.""" freeze_gc("Scheduler") - self.send_to_detokenizer.send_pyobj(recv_req) + self.send_to_detokenizer.send_output(recv_req, recv_req) return None diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index b238f6c25..02b62c0e8 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -682,6 +682,7 @@ class SchedulerOutputProcessorMixin: skip_req: Optional[Req] = None, ): rids = [] + http_worker_ipcs = [] finished_reasons: List[BaseFinishReason] = [] decoded_texts = [] @@ -770,6 +771,7 @@ class SchedulerOutputProcessorMixin: req.send_output_token_logprobs_offset ) rids.append(req.rid) + http_worker_ipcs.append(req.http_worker_ipc) finished_reasons.append( req.finished_reason.to_json() if req.finished_reason else None ) @@ -886,7 +888,7 @@ class SchedulerOutputProcessorMixin: if self.model_config.is_multimodal_gen: return - self.send_to_detokenizer.send_pyobj( + self.send_to_detokenizer.send_output( BatchTokenIDOutput( finished_reasons, decoded_texts, @@ -916,6 +918,7 @@ class SchedulerOutputProcessorMixin: output_token_entropy_val=None, output_hidden_states=output_hidden_states, rids=rids, + http_worker_ipcs=http_worker_ipcs, placeholder_tokens_idx=None, placeholder_tokens_val=None, ) @@ -923,6 +926,7 @@ class SchedulerOutputProcessorMixin: def stream_output_embedding(self: Scheduler, reqs: List[Req]): rids = [] + http_worker_ipcs = [] finished_reasons: List[BaseFinishReason] = [] embeddings = [] @@ -931,17 +935,19 @@ class SchedulerOutputProcessorMixin: for req in reqs: if req.finished(): rids.append(req.rid) + http_worker_ipcs.append(req.http_worker_ipc) finished_reasons.append(req.finished_reason.to_json()) embeddings.append(req.embedding) prompt_tokens.append(len(req.origin_input_ids)) cached_tokens.append(req.cached_tokens) - self.send_to_detokenizer.send_pyobj( + self.send_to_detokenizer.send_output( BatchEmbeddingOutput( finished_reasons, embeddings, prompt_tokens, cached_tokens, rids=rids, + http_worker_ipcs=http_worker_ipcs, placeholder_tokens_idx=None, placeholder_tokens_val=None, ) diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index d40766b1d..c0283d05d 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio import copy import logging -import os import time import uuid from collections import deque @@ -46,7 +45,6 @@ from sglang.srt.managers.io_struct import ( LoadLoRAAdapterReqInput, LoadLoRAAdapterReqOutput, LoRAUpdateOutput, - MultiTokenizerWrapper, OpenSessionReqInput, ProfileReq, ProfileReqOutput, @@ -83,8 +81,6 @@ logger = logging.getLogger(__name__) class _Communicator(Generic[T]): """Note: The communicator now only run up to 1 in-flight request at any time.""" - enable_multi_tokenizer = False - def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"): self._sender = sender self._fan_out = fan_out @@ -104,8 +100,6 @@ class _Communicator(Generic[T]): assert self._result_values is None if obj: - if _Communicator.enable_multi_tokenizer: - obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj) self._sender.send_pyobj(obj) self._result_event = asyncio.Event() @@ -126,8 +120,6 @@ class _Communicator(Generic[T]): self._result_event = asyncio.Event() if obj: - if _Communicator.enable_multi_tokenizer: - obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj) self._sender.send_pyobj(obj) await self._result_event.wait() @@ -617,8 +609,6 @@ class TokenizerCommunicatorMixin: elif obj.session_id in self.session_futures: return None - if self.server_args.tokenizer_worker_num > 1: - obj = MultiTokenizerWrapper(self.worker_id, obj) self.send_to_scheduler.send_pyobj(obj) self.session_futures[obj.session_id] = asyncio.Future() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 4b90411d5..63eaaa268 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -46,6 +46,7 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, + BaseReq, BatchEmbeddingOutput, BatchMultimodalOutput, BatchStrOutput, @@ -58,7 +59,6 @@ from sglang.srt.managers.io_struct import ( GenerateReqInput, GetLoadReqInput, HealthCheckOutput, - MultiTokenizerWrapper, OpenSessionReqOutput, SessionParams, TokenizedEmbeddingReqInput, @@ -88,7 +88,6 @@ from sglang.srt.utils import ( dataclass_to_string_truncated, freeze_gc, get_bool_env_var, - get_origin_rid, get_zmq_socket, kill_process_tree, ) @@ -258,9 +257,18 @@ class TokenizerManager(TokenizerCommunicatorMixin): ) if self.server_args.tokenizer_worker_num > 1: # Use tokenizer_worker_ipc_name in multi-tokenizer mode - self.send_to_scheduler = get_zmq_socket( + send_to_scheduler = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False ) + + class SenderWrapper: + def send_pyobj(self, obj): + if isinstance(obj, BaseReq): + obj.http_worker_ipc = port_args.tokenizer_ipc_name + send_to_scheduler.send_pyobj(obj) + + # Make sure that each request carries the tokenizer_ipc_name for response routing + self.send_to_scheduler = SenderWrapper() else: self.send_to_scheduler = get_zmq_socket( context, zmq.PUSH, port_args.scheduler_input_ipc_name, True @@ -376,13 +384,10 @@ class TokenizerManager(TokenizerCommunicatorMixin): obj.normalize_batch_and_arguments() if self.server_args.tokenizer_worker_num > 1: - # Modify rid, add worker_id - if isinstance(obj.rid, list): - # If it's an array, add worker_id prefix to each element - obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid] - else: - # If it's a single value, add worker_id prefix - obj.rid = f"{self.worker_id}_{obj.rid}" + from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker + + assert isinstance(self, TokenizerWorker) + self._attach_multi_http_worker_info(obj) if self.enable_trace: self._trace_request_start(obj, created_time) @@ -728,6 +733,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): obj.token_ids_logprob, obj.stream, rid=obj.rid, + http_worker_ipc=obj.http_worker_ipc, bootstrap_host=obj.bootstrap_host, bootstrap_port=obj.bootstrap_port, bootstrap_room=obj.bootstrap_room, @@ -749,6 +755,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): sampling_params, rid=obj.rid, priority=obj.priority, + http_worker_ipc=obj.http_worker_ipc, ) return tokenized_obj @@ -1109,8 +1116,6 @@ class TokenizerManager(TokenizerCommunicatorMixin): async def _wait_for_model_update_from_disk( self, obj: UpdateWeightFromDiskReqInput ) -> Tuple[bool, str]: - if self.server_args.tokenizer_worker_num > 1: - obj = MultiTokenizerWrapper(self.worker_id, obj) self.send_to_scheduler.send_pyobj(obj) self.model_update_result = asyncio.Future() if self.server_args.dp_size == 1: @@ -1349,12 +1354,9 @@ class TokenizerManager(TokenizerCommunicatorMixin): ) continue - origin_rid = rid - if self.server_args.tokenizer_worker_num > 1: - origin_rid = get_origin_rid(rid) # Build meta_info and return value meta_info = { - "id": origin_rid, + "id": rid, "finish_reason": recv_obj.finished_reasons[i], "prompt_tokens": recv_obj.prompt_tokens[i], "weight_version": self.server_args.weight_version, @@ -1708,9 +1710,6 @@ class TokenizerManager(TokenizerCommunicatorMixin): if is_health_check_generate_req(recv_obj): return state = self.rid_to_state[recv_obj.rid] - origin_rid = recv_obj.rid - if self.server_args.tokenizer_worker_num > 1: - origin_rid = get_origin_rid(origin_rid) state.finished = True if recv_obj.finished_reason: out = { @@ -1723,7 +1722,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): out = { "text": "", "meta_info": { - "id": origin_rid, + "id": recv_obj.rid, "finish_reason": { "type": "abort", "message": "Abort before prefill", diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 148a73bf8..42fb93374 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -3006,10 +3006,6 @@ def lru_cache_frozenset(maxsize=128): return decorator -def get_origin_rid(rid): - return rid.split("_", 1)[1] if "_" in rid else rid - - def apply_module_patch(target_module, target_function, wrappers): original_module, original_function = parse_module_path( target_module, target_function, False