Simplify multi-tokenizer (#11295)
Signed-off-by: zhengkezhou1 <madzhou1@gmail.com>
Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
@@ -24,7 +24,6 @@ from collections import deque
 from concurrent import futures
 from dataclasses import dataclass
 from http import HTTPStatus
-from types import SimpleNamespace
 from typing import Deque, Dict, List, Optional, Tuple, Union

 import psutil
@@ -66,6 +65,8 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BaseBatchReq,
+    BaseReq,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
     ClearHiCacheReqInput,
@@ -89,8 +90,6 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
-    MultiTokenizerRegisterReq,
-    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -277,47 +276,7 @@ class Scheduler(
         self.model_config = ModelConfig.from_server_args(server_args)

         # Init inter-process communication
-        context = zmq.Context(2)
-        self.idle_sleeper = None
-        if self.pp_rank == 0 and self.attn_tp_rank == 0:
-            self.recv_from_tokenizer = get_zmq_socket(
-                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
-            )
-            self.recv_from_rpc = get_zmq_socket(
-                context, zmq.DEALER, port_args.rpc_ipc_name, False
-            )
-
-            self.send_to_tokenizer = get_zmq_socket(
-                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
-            )
-            if server_args.skip_tokenizer_init:
-                # Directly send to the TokenizerManager
-                self.send_to_detokenizer = get_zmq_socket(
-                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
-                )
-            else:
-                # Send to the DetokenizerManager
-                self.send_to_detokenizer = get_zmq_socket(
-                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
-                )
-
-            if self.server_args.sleep_on_idle:
-                self.idle_sleeper = IdleSleeper(
-                    [
-                        self.recv_from_tokenizer,
-                        self.recv_from_rpc,
-                    ]
-                )
-        else:
-            self.recv_from_tokenizer = None
-            self.recv_from_rpc = None
-            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
-            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
-
-        if self.current_scheduler_metrics_enabled():
-            self.send_metrics_from_scheduler = get_zmq_socket(
-                context, zmq.PUSH, port_args.metrics_ipc_name, False
-            )
+        self.init_sockets(server_args, port_args)

         # Init tokenizer
         self.init_tokenizer()
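Note: the inline socket setup deleted above reappears, lightly reworked, as the init_sockets method added in a later hunk. The channels themselves are ordinary pyzmq PUSH/PULL pairs over IPC; a minimal standalone sketch of that pairing follows (the endpoint name is invented for illustration, not one of sglang's):

# Minimal pyzmq PUSH/PULL pairing, the same pattern get_zmq_socket wires up.
import zmq

ctx = zmq.Context(2)  # 2 IO threads, matching zmq.Context(2) in the diff
endpoint = "ipc:///tmp/example_scheduler_input"  # illustrative name only

pull = ctx.socket(zmq.PULL)   # receiving side (scheduler's recv_from_tokenizer)
pull.bind(endpoint)

push = ctx.socket(zmq.PUSH)   # sending side (tokenizer's outbound socket)
push.connect(endpoint)

push.send_pyobj({"rid": "r1"})  # pickles and sends an arbitrary object
print(pull.recv_pyobj())        # -> {'rid': 'r1'}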
@@ -578,7 +537,6 @@ class Scheduler(
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
-                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
                 (GetLoadReqInput, self.get_load),
             ]
         )
@@ -634,6 +592,75 @@ class Scheduler(
         else:
             self.draft_worker = None

+    def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
+        context = zmq.Context(2)
+        self.idle_sleeper = None
+
+        class SenderWrapper:
+            def __init__(self, socket: zmq.Socket):
+                self.socket = socket
+
+            def send_output(
+                self,
+                output: Union[BaseReq, BaseBatchReq],
+                recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
+            ):
+                if self.socket is None:
+                    return
+
+                if (
+                    isinstance(recv_obj, BaseReq)
+                    and recv_obj.http_worker_ipc is not None
+                    and output.http_worker_ipc is None
+                ):
+                    # handle communicator reqs for multi-http worker case
+                    output.http_worker_ipc = recv_obj.http_worker_ipc
+
+                self.socket.send_pyobj(output)
+
+        if self.pp_rank == 0 and self.attn_tp_rank == 0:
+            self.recv_from_tokenizer = get_zmq_socket(
+                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
+            )
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
+
+            send_to_tokenizer = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
+            )
+            if server_args.skip_tokenizer_init:
+                # Directly send to the TokenizerManager
+                send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
+                )
+            else:
+                # Send to the DetokenizerManager
+                send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
+                )
+
+            self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
+            self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)
+
+            if self.server_args.sleep_on_idle:
+                self.idle_sleeper = IdleSleeper(
+                    [
+                        self.recv_from_tokenizer,
+                        self.recv_from_rpc,
+                    ]
+                )
+        else:
+            self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
+            self.send_to_tokenizer = SenderWrapper(None)
+            self.send_to_detokenizer = SenderWrapper(None)
+
+        if self.current_scheduler_metrics_enabled():
+            self.send_metrics_from_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.metrics_ipc_name, False
+            )
+
     def init_deterministic_inference_config(self):
         """Initialize deterministic inference configuration for different attention backends."""
         if not self.server_args.enable_deterministic_inference:
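The SenderWrapper added above does two jobs: it is a null object (a rank that owns no sockets gets SenderWrapper(None), and send_output silently drops, replacing the old SimpleNamespace(send_pyobj=lambda x: None) stub), and it stamps replies with the originating HTTP worker's IPC name so multi-worker responses can be routed back. A self-contained sketch of the same pattern, with hypothetical stand-in types in place of sglang's BaseReq and zmq.Socket:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeReq:  # stand-in for a BaseReq; only the routing field matters here
    http_worker_ipc: Optional[str] = None

class FakeSocket:  # stand-in for zmq.Socket
    def send_pyobj(self, obj):
        print(f"sent: {obj}")

class SenderWrapper:
    def __init__(self, socket):
        self.socket = socket

    def send_output(self, output, recv_obj=None):
        if self.socket is None:  # null object: non-owning ranks drop sends
            return
        # Copy the return address from the incoming request if the reply
        # does not already carry one (multi-HTTP-worker routing).
        if (
            recv_obj is not None
            and recv_obj.http_worker_ipc is not None
            and output.http_worker_ipc is None
        ):
            output.http_worker_ipc = recv_obj.http_worker_ipc
        self.socket.send_pyobj(output)

SenderWrapper(None).send_output(FakeReq())  # no-op on non-owning ranks
SenderWrapper(FakeSocket()).send_output(
    FakeReq(), recv_obj=FakeReq(http_worker_ipc="ipc:///tmp/http_worker_0")
)  # prints: sent: FakeReq(http_worker_ipc='ipc:///tmp/http_worker_0')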
@@ -1107,23 +1134,13 @@ class Scheduler(
                 self.return_health_check_ct += 1
                 continue

-            # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
-            if isinstance(recv_req, MultiTokenizerWrapper):
-                worker_id = recv_req.worker_id
-                recv_req = recv_req.obj
-                output = self._request_dispatcher(recv_req)
-                if output is not None:
-                    output = MultiTokenizerWrapper(worker_id, output)
-                    self.send_to_tokenizer.send_pyobj(output)
-                continue
-
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
                     if self.recv_from_rpc is not None:
                         self.recv_from_rpc.send_pyobj(output)
                 else:
-                    self.send_to_tokenizer.send_pyobj(output)
+                    self.send_to_tokenizer.send_output(output, recv_req)

     def init_req_max_new_tokens(self, req):
         req.sampling_params.max_new_tokens = min(
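This hunk is the heart of the simplification: requests no longer arrive wrapped in MultiTokenizerWrapper with a worker_id that has to be unwrapped, dispatched, and re-wrapped; instead each request carries its own http_worker_ipc return address, and send_output copies it onto the reply. A self-contained sketch of that flow (all names here are hypothetical stand-ins, not sglang's):

from typing import Any, Callable, Dict, Optional, Type

class EchoReq:  # stand-in for a BaseReq subclass
    def __init__(self, rid: str, http_worker_ipc: Optional[str] = None):
        self.rid = rid
        self.http_worker_ipc = http_worker_ipc

def dispatch(handlers: Dict[Type, Callable[[Any], Any]], recv_req: Any) -> Any:
    # Plays the role of Scheduler._request_dispatcher: route by request type.
    return handlers[type(recv_req)](recv_req)

# The handler builds a reply without knowing which HTTP worker asked.
handlers: Dict[Type, Callable[[Any], Any]] = {EchoReq: lambda r: EchoReq(r.rid)}

req = EchoReq("r1", http_worker_ipc="ipc:///tmp/http_worker_0")
output = dispatch(handlers, req)

# send_output-style stamping: the reply inherits the request's return address.
if req.http_worker_ipc is not None and output.http_worker_ipc is None:
    output.http_worker_ipc = req.http_worker_ipc

print(output.rid, output.http_worker_ipc)  # r1 ipc:///tmp/http_worker_0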
@@ -1179,6 +1196,7 @@ class Scheduler(
             metrics_collector=(
                 self.metrics_collector if self.enable_metrics else None
             ),
+            http_worker_ipc=recv_req.http_worker_ipc,
         )
         req.tokenizer = self.tokenizer

@@ -1382,7 +1400,7 @@ class Scheduler(
                 },
                 rid=req.rid,
             )
-            self.send_to_tokenizer.send_pyobj(abort_req)
+            self.send_to_tokenizer.send_output(abort_req, req)

     def _abort_on_queued_limit(self, recv_req: Req) -> bool:
         """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
@@ -1414,7 +1432,7 @@ class Scheduler(
             req_to_abort = candidate_req
             message = "The request is aborted by a higher priority request."

-        self.send_to_tokenizer.send_pyobj(
+        self.send_to_tokenizer.send_output(
             AbortReq(
                 finished_reason={
                     "type": "abort",
@@ -1422,7 +1440,8 @@ class Scheduler(
                     "message": message,
                 },
                 rid=req_to_abort.rid,
-            )
+            ),
+            req_to_abort,
         )
         return req_to_abort.rid == recv_req.rid

@@ -1437,6 +1456,7 @@ class Scheduler(
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
             priority=recv_req.priority,
+            http_worker_ipc=recv_req.http_worker_ipc,
         )
         req.tokenizer = self.tokenizer

@@ -1953,8 +1973,8 @@ class Scheduler(
         self.num_retracted_reqs = len(retracted_reqs)
         self.new_token_ratio = new_token_ratio
         for req in reqs_to_abort:
-            self.send_to_tokenizer.send_pyobj(
-                AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+            self.send_to_tokenizer.send_output(
+                AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
             )

         logger.info(
@@ -2138,7 +2158,7 @@ class Scheduler(
             # This is used to prevent the health check signal being blocked by long context prefill.
             # However, one minor issue is that this code path does not check the status of detokenizer manager.
             self.return_health_check_ct -= 1
-            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
+            self.send_to_tokenizer.send_output(HealthCheckOutput())

     def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
         return self.prepare_mlp_sync_batch_raw(
@@ -2585,7 +2605,7 @@ class Scheduler(
             if self.enable_hicache_storage:
                 # to release prefetch events associated with the request
                 self.tree_cache.release_aborted_request(req.rid)
-            self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
+            self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
             # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
             if self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.tree_cache.cache_finished_req(req)
@@ -2669,10 +2689,6 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

-    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
-        self.send_to_detokenizer.send_pyobj(recv_req)
-        return recv_req
-
     def init_weights_send_group_for_remote_instance(
         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
     ):
@@ -2751,7 +2767,7 @@ class Scheduler(
     def handle_freeze_gc(self, recv_req: FreezeGCReq):
         """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
         freeze_gc("Scheduler")
-        self.send_to_detokenizer.send_pyobj(recv_req)
+        self.send_to_detokenizer.send_output(recv_req, recv_req)
         return None
