Simplify multi-tokenizer (#11295)

Signed-off-by: zhengkezhou1 <madzhou1@gmail.com>
Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
This commit is contained in:
Zhengke Zhou
2025-10-21 16:33:29 +08:00
committed by GitHub
parent dbb16bedd5
commit 260fe755b6
10 changed files with 174 additions and 204 deletions

View File

@@ -24,7 +24,6 @@ from collections import deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from typing import Deque, Dict, List, Optional, Tuple, Union
import psutil
@@ -66,6 +65,8 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import (
AbortReq,
BaseBatchReq,
BaseReq,
BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput,
ClearHiCacheReqInput,
@@ -89,8 +90,6 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
MultiTokenizerRegisterReq,
MultiTokenizerWrapper,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -277,47 +276,7 @@ class Scheduler(
self.model_config = ModelConfig.from_server_args(server_args)
# Init inter-process communication
context = zmq.Context(2)
self.idle_sleeper = None
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
else:
# Send to the DetokenizerManager
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
if self.server_args.sleep_on_idle:
self.idle_sleeper = IdleSleeper(
[
self.recv_from_tokenizer,
self.recv_from_rpc,
]
)
else:
self.recv_from_tokenizer = None
self.recv_from_rpc = None
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
if self.current_scheduler_metrics_enabled():
self.send_metrics_from_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.metrics_ipc_name, False
)
self.init_sockets(server_args, port_args)
# Init tokenizer
self.init_tokenizer()
@@ -578,7 +537,6 @@ class Scheduler(
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
(GetLoadReqInput, self.get_load),
]
)
@@ -634,6 +592,75 @@ class Scheduler(
else:
self.draft_worker = None
def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
    """Create the inter-process ZMQ channels to the tokenizer/detokenizer.

    Only the rank with ``pp_rank == 0`` and ``attn_tp_rank == 0`` owns real
    sockets; every other rank gets no-op senders so call sites can invoke
    ``send_output`` unconditionally without rank checks.

    Args:
        server_args: Server configuration (``skip_tokenizer_init``,
            ``sleep_on_idle`` are read here).
        port_args: IPC endpoint names for the scheduler/tokenizer/
            detokenizer/rpc/metrics channels.
    """
    # NOTE(review): the surrounding __init__ also creates a zmq.Context(2);
    # confirm that having two contexts here is intended.
    context = zmq.Context(2)
    self.idle_sleeper = None

    class SenderWrapper:
        # Thin wrapper over a PUSH socket. A ``None`` socket turns every
        # send into a no-op (used on non-leader ranks).
        def __init__(self, socket: zmq.Socket):
            self.socket = socket

        def send_output(
            self,
            output: Union[BaseReq, BaseBatchReq],
            recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
        ):
            # Non-leader ranks hold no socket; silently drop the message.
            if self.socket is None:
                return
            if (
                isinstance(recv_obj, BaseReq)
                and recv_obj.http_worker_ipc is not None
                and output.http_worker_ipc is None
            ):
                # handle communicator reqs for multi-http worker case
                output.http_worker_ipc = recv_obj.http_worker_ipc
            self.socket.send_pyobj(output)

    if self.pp_rank == 0 and self.attn_tp_rank == 0:
        # Leader rank: open the real IPC channels.
        self.recv_from_tokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.scheduler_input_ipc_name, False
        )
        self.recv_from_rpc = get_zmq_socket(
            context, zmq.DEALER, port_args.rpc_ipc_name, False
        )
        send_to_tokenizer = get_zmq_socket(
            context, zmq.PUSH, port_args.tokenizer_ipc_name, False
        )
        if server_args.skip_tokenizer_init:
            # Directly send to the TokenizerManager
            send_to_detokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
        else:
            # Send to the DetokenizerManager
            send_to_detokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.detokenizer_ipc_name, False
            )
        self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
        self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)
        if self.server_args.sleep_on_idle:
            # Sleep the process when both receive sockets are idle.
            self.idle_sleeper = IdleSleeper(
                [
                    self.recv_from_tokenizer,
                    self.recv_from_rpc,
                ]
            )
    else:
        # Non-leader ranks: no receive sockets, no-op senders.
        self.recv_from_tokenizer = None
        self.recv_from_rpc = None
        self.send_to_tokenizer = SenderWrapper(None)
        self.send_to_detokenizer = SenderWrapper(None)
    if self.current_scheduler_metrics_enabled():
        self.send_metrics_from_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.metrics_ipc_name, False
        )
def init_deterministic_inference_config(self):
"""Initialize deterministic inference configuration for different attention backends."""
if not self.server_args.enable_deterministic_inference:
@@ -1107,23 +1134,13 @@ class Scheduler(
self.return_health_check_ct += 1
continue
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWrapper):
worker_id = recv_req.worker_id
recv_req = recv_req.obj
output = self._request_dispatcher(recv_req)
if output is not None:
output = MultiTokenizerWrapper(worker_id, output)
self.send_to_tokenizer.send_pyobj(output)
continue
output = self._request_dispatcher(recv_req)
if output is not None:
if isinstance(output, RpcReqOutput):
if self.recv_from_rpc is not None:
self.recv_from_rpc.send_pyobj(output)
else:
self.send_to_tokenizer.send_pyobj(output)
self.send_to_tokenizer.send_output(output, recv_req)
def init_req_max_new_tokens(self, req):
req.sampling_params.max_new_tokens = min(
@@ -1179,6 +1196,7 @@ class Scheduler(
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
http_worker_ipc=recv_req.http_worker_ipc,
)
req.tokenizer = self.tokenizer
@@ -1382,7 +1400,7 @@ class Scheduler(
},
rid=req.rid,
)
self.send_to_tokenizer.send_pyobj(abort_req)
self.send_to_tokenizer.send_output(abort_req, req)
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
@@ -1414,7 +1432,7 @@ class Scheduler(
req_to_abort = candidate_req
message = "The request is aborted by a higher priority request."
self.send_to_tokenizer.send_pyobj(
self.send_to_tokenizer.send_output(
AbortReq(
finished_reason={
"type": "abort",
@@ -1422,7 +1440,8 @@ class Scheduler(
"message": message,
},
rid=req_to_abort.rid,
)
),
req_to_abort,
)
return req_to_abort.rid == recv_req.rid
@@ -1437,6 +1456,7 @@ class Scheduler(
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
priority=recv_req.priority,
http_worker_ipc=recv_req.http_worker_ipc,
)
req.tokenizer = self.tokenizer
@@ -1953,8 +1973,8 @@ class Scheduler(
self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
self.send_to_tokenizer.send_pyobj(
AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
self.send_to_tokenizer.send_output(
AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
)
logger.info(
@@ -2138,7 +2158,7 @@ class Scheduler(
# This is used to prevent the health check signal being blocked by long context prefill.
# However, one minor issue is that this code path does not check the status of detokenizer manager.
self.return_health_check_ct -= 1
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
self.send_to_tokenizer.send_output(HealthCheckOutput())
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
return self.prepare_mlp_sync_batch_raw(
@@ -2585,7 +2605,7 @@ class Scheduler(
if self.enable_hicache_storage:
# to release prefetch events associated with the request
self.tree_cache.release_aborted_request(req.rid)
self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
if self.disaggregation_mode == DisaggregationMode.DECODE:
self.tree_cache.cache_finished_req(req)
@@ -2669,10 +2689,6 @@ class Scheduler(
result = self.tp_worker.unload_lora_adapter(recv_req)
return result
def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
    """Relay a multi-tokenizer registration request to the detokenizer.

    Returns the request itself so the dispatcher sends it back to the
    tokenizer side as an acknowledgement.
    """
    detokenizer_channel = self.send_to_detokenizer
    detokenizer_channel.send_pyobj(recv_req)
    return recv_req
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
@@ -2751,7 +2767,7 @@ class Scheduler(
def handle_freeze_gc(self, recv_req: FreezeGCReq):
    """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
    freeze_gc("Scheduler")
    # Forward exactly once via send_output (the rendered diff fused the
    # removed legacy send_pyobj line with this one, which would send the
    # request twice). Passing recv_req as recv_obj lets send_output
    # propagate http_worker_ipc for the multi-http-worker case.
    self.send_to_detokenizer.send_output(recv_req, recv_req)
    return None