Simplify multi-tokenizer (#11295)
Signed-off-by: zhengkezhou1 <madzhou1@gmail.com> Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
This commit is contained in:
@@ -149,15 +149,14 @@ def set_global_state(global_state: _GlobalState):
|
||||
|
||||
async def init_multi_tokenizer() -> ServerArgs:
|
||||
"""Read args information from shm and init tokenizer manager for current process"""
|
||||
pid = os.getpid()
|
||||
main_pid = get_main_process_id()
|
||||
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
|
||||
|
||||
# Read configuration from shared memory
|
||||
main_pid = get_main_process_id()
|
||||
port_args, server_args, scheduler_info = read_from_shared_memory(
|
||||
f"multi_tokenizer_args_{main_pid}"
|
||||
)
|
||||
server_args: ServerArgs
|
||||
port_args: PortArgs
|
||||
|
||||
# API key authentication is not supported in multi-tokenizer mode
|
||||
assert (
|
||||
@@ -167,6 +166,10 @@ async def init_multi_tokenizer() -> ServerArgs:
|
||||
port_args.tokenizer_ipc_name = (
|
||||
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
|
||||
)
|
||||
logger.info(
|
||||
f"Start multi-tokenizer worker process {os.getpid()}, "
|
||||
f"ipc_name={port_args.tokenizer_ipc_name}"
|
||||
)
|
||||
|
||||
# Launch multi-tokenizer manager process
|
||||
tokenizer_manager = TokenizerWorker(server_args, port_args)
|
||||
@@ -177,8 +180,6 @@ async def init_multi_tokenizer() -> ServerArgs:
|
||||
chat_template=server_args.chat_template,
|
||||
completion_template=server_args.completion_template,
|
||||
)
|
||||
# Register this tokenizer with the main tokenizer manager
|
||||
await tokenizer_manager.register_to_main_tokenizer_manager()
|
||||
|
||||
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
||||
set_global_state(
|
||||
|
||||
@@ -31,7 +31,6 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchStrOutput,
|
||||
BatchTokenIDOutput,
|
||||
FreezeGCReq,
|
||||
MultiTokenizerRegisterReq,
|
||||
)
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
@@ -104,7 +103,6 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
(BatchEmbeddingOutput, self.handle_batch_embedding_out),
|
||||
(BatchTokenIDOutput, self.handle_batch_token_id_out),
|
||||
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||
(MultiTokenizerRegisterReq, lambda x: x),
|
||||
(FreezeGCReq, self.handle_freeze_gc_req),
|
||||
]
|
||||
)
|
||||
@@ -227,6 +225,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
|
||||
return BatchStrOutput(
|
||||
rids=recv_obj.rids,
|
||||
http_worker_ipcs=recv_obj.http_worker_ipcs,
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
output_strs=output_strs,
|
||||
output_ids=recv_obj.decode_ids,
|
||||
@@ -258,6 +257,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
outputs = self.tokenizer.detokenize(recv_obj)
|
||||
return BatchMultimodalOutput(
|
||||
rids=recv_obj.rids,
|
||||
http_worker_ipcs=recv_obj.http_worker_ipcs,
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
outputs=outputs,
|
||||
prompt_tokens=recv_obj.prompt_tokens,
|
||||
|
||||
@@ -39,6 +39,7 @@ else:
|
||||
@dataclass
|
||||
class BaseReq(ABC):
|
||||
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
|
||||
http_worker_ipc: Optional[str] = field(default=None, kw_only=True)
|
||||
|
||||
def regenerate_rid(self):
|
||||
"""Generate a new request ID and return it."""
|
||||
@@ -52,6 +53,7 @@ class BaseReq(ABC):
|
||||
@dataclass
|
||||
class BaseBatchReq(ABC):
|
||||
rids: Optional[List[str]] = field(default=None, kw_only=True)
|
||||
http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)
|
||||
|
||||
def regenerate_rids(self):
|
||||
"""Generate new request IDs and return them."""
|
||||
@@ -1407,18 +1409,6 @@ class LoRAUpdateOutput(BaseReq):
|
||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
|
||||
|
||||
|
||||
@dataclass
class MultiTokenizerRegisterReq(BaseBatchReq):
    """Registration message sent by a tokenizer worker.

    Carries the worker's IPC endpoint name in ``ipc_name`` on top of the
    fields inherited from ``BaseBatchReq``.
    """

    # IPC endpoint name of the registering worker; None until the sender sets it.
    ipc_name: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
class MultiTokenizerWrapper:
    """Envelope pairing a payload object with the id of the worker it belongs to."""

    # FIXME(lsyin): remove this

    # Identifier of the tokenizer worker associated with the payload.
    worker_id: int
    # The wrapped payload; None when there is nothing to carry.
    obj: Optional[Any] = field(default=None)
|
||||
|
||||
|
||||
class BlockReqType(Enum):
|
||||
BLOCK = 1
|
||||
UNBLOCK = 2
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -21,7 +23,7 @@ import sys
|
||||
import threading
|
||||
from functools import partialmethod
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any, Dict
|
||||
from typing import TYPE_CHECKING, Any, Dict, Union
|
||||
|
||||
import setproctitle
|
||||
import zmq
|
||||
@@ -30,12 +32,12 @@ import zmq.asyncio
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BaseBatchReq,
|
||||
BaseReq,
|
||||
BatchEmbeddingOutput,
|
||||
BatchMultimodalOutput,
|
||||
BatchStrOutput,
|
||||
BatchTokenIDOutput,
|
||||
MultiTokenizerRegisterReq,
|
||||
MultiTokenizerWrapper,
|
||||
)
|
||||
from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -43,6 +45,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.detokenizer_manager import DetokenizerManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,29 +61,24 @@ class SocketMapping:
|
||||
socket.close()
|
||||
self._mapping.clear()
|
||||
|
||||
def register_ipc_mapping(
|
||||
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
|
||||
):
|
||||
def _register_ipc_mapping(self, ipc_name: str, is_tokenizer: bool):
|
||||
type_str = "tokenizer" if is_tokenizer else "detokenizer"
|
||||
if worker_id in self._mapping:
|
||||
logger.warning(
|
||||
f"{type_str} already registered with worker {worker_id}, skipping..."
|
||||
)
|
||||
if ipc_name in self._mapping:
|
||||
logger.warning(f"{type_str} already registered {ipc_name=}, skipping...")
|
||||
return
|
||||
logger.info(
|
||||
f"{type_str} not registered with worker {worker_id}, registering..."
|
||||
)
|
||||
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
|
||||
self._mapping[worker_id] = socket
|
||||
self._mapping[worker_id].send_pyobj(recv_obj)
|
||||
logger.info(f"Registering {type_str} {ipc_name=} in SocketMapping...")
|
||||
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
|
||||
self._mapping[ipc_name] = socket
|
||||
|
||||
def send_output(self, worker_id: str, output: Any):
|
||||
if worker_id not in self._mapping:
|
||||
logger.error(
|
||||
f"worker ID {worker_id} not registered. Check if the server Process is alive"
|
||||
)
|
||||
def send_output(self, ipc_name: str, output: Any):
|
||||
if ipc_name is None:
|
||||
# Some unhandled cases
|
||||
logger.warning(f"IPC name is None, output type={type(output)}, skipping...")
|
||||
return
|
||||
self._mapping[worker_id].send_pyobj(output)
|
||||
|
||||
if ipc_name not in self._mapping:
|
||||
self._register_ipc_mapping(ipc_name, is_tokenizer=False)
|
||||
self._mapping[ipc_name].send_pyobj(output)
|
||||
|
||||
|
||||
def _handle_output_by_index(output, i):
|
||||
@@ -362,20 +362,11 @@ def _handle_output_by_index(output, i):
|
||||
class MultiHttpWorkerDetokenizerMixin:
|
||||
"""Mixin class for DetokenizerManager"""
|
||||
|
||||
def get_worker_ids_from_req_rids(self, rids):
    """Extract the integer worker id encoded at the front of each request id.

    The worker id is the text before the first ``"_"`` of a rid, parsed
    with ``int``.

    Args:
        rids: A list of rid strings, a single rid string, or anything else.

    Returns:
        One worker id per rid; an empty list when ``rids`` is neither a
        list nor a string.
    """
    if isinstance(rids, str):
        # A lone rid is handled as a one-element batch.
        rids = [rids]
    elif not isinstance(rids, list):
        return []
    return [int(rid.partition("_")[0]) for rid in rids]
|
||||
|
||||
def maybe_clear_socket_mapping(self):
|
||||
def maybe_clear_socket_mapping(self: DetokenizerManager):
|
||||
if hasattr(self, "socket_mapping"):
|
||||
self.socket_mapping.clear_all_sockets()
|
||||
|
||||
def multi_http_worker_event_loop(self):
|
||||
def multi_http_worker_event_loop(self: DetokenizerManager):
|
||||
"""The event loop that handles requests, for multi multi-http-worker mode"""
|
||||
self.socket_mapping = SocketMapping()
|
||||
while True:
|
||||
@@ -383,23 +374,15 @@ class MultiHttpWorkerDetokenizerMixin:
|
||||
output = self._request_dispatcher(recv_obj)
|
||||
if output is None:
|
||||
continue
|
||||
# Extract worker_id from rid
|
||||
if isinstance(recv_obj.rids, list):
|
||||
worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
recv_obj, BaseBatchReq
|
||||
), "for multi-http-worker, recv_obj must be BaseBatchReq"
|
||||
|
||||
# Send data using the corresponding socket
|
||||
for i, worker_id in enumerate(worker_ids):
|
||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||
self.socket_mapping.register_ipc_mapping(
|
||||
recv_obj, worker_id, is_tokenizer=False
|
||||
)
|
||||
else:
|
||||
new_output = _handle_output_by_index(output, i)
|
||||
self.socket_mapping.send_output(worker_id, new_output)
|
||||
for i, ipc_name in enumerate(recv_obj.http_worker_ipcs):
|
||||
new_output = _handle_output_by_index(output, i)
|
||||
self.socket_mapping.send_output(ipc_name, new_output)
|
||||
|
||||
|
||||
class MultiTokenizerRouter:
|
||||
@@ -449,26 +432,17 @@ class MultiTokenizerRouter:
|
||||
await self._distribute_result_to_workers(recv_obj)
|
||||
|
||||
async def _distribute_result_to_workers(self, recv_obj):
|
||||
"""Distribute result to corresponding workers based on rid"""
|
||||
if isinstance(recv_obj, MultiTokenizerWrapper):
|
||||
worker_ids = [recv_obj.worker_id]
|
||||
recv_obj = recv_obj.obj
|
||||
else:
|
||||
worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
|
||||
|
||||
if len(worker_ids) == 0:
|
||||
logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
|
||||
return
|
||||
|
||||
# Distribute result to each worker
|
||||
for i, worker_id in enumerate(worker_ids):
|
||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||
self.socket_mapping.register_ipc_mapping(
|
||||
recv_obj, worker_id, is_tokenizer=True
|
||||
)
|
||||
else:
|
||||
new_recv_obj = _handle_output_by_index(recv_obj, i)
|
||||
self.socket_mapping.send_output(worker_id, new_recv_obj)
|
||||
if isinstance(recv_obj, BaseReq):
|
||||
ipc_names = [recv_obj.http_worker_ipc]
|
||||
elif isinstance(recv_obj, BaseBatchReq):
|
||||
ipc_names = recv_obj.http_worker_ipcs
|
||||
else:
|
||||
raise ValueError(f"Unknown recv_obj type: {type(recv_obj)}")
|
||||
|
||||
for i, ipc_name in enumerate(ipc_names):
|
||||
new_recv_obj = _handle_output_by_index(recv_obj, i)
|
||||
self.socket_mapping.send_output(ipc_name, new_recv_obj)
|
||||
|
||||
|
||||
class TokenizerWorker(TokenizerManager):
|
||||
@@ -500,21 +474,15 @@ class TokenizerWorker(TokenizerManager):
|
||||
self.register_multi_tokenizer_communicator = _Communicator(
|
||||
self.send_to_scheduler, 2
|
||||
)
|
||||
self._result_dispatcher._mapping.append(
|
||||
(
|
||||
MultiTokenizerRegisterReq,
|
||||
self.register_multi_tokenizer_communicator.handle_recv,
|
||||
)
|
||||
)
|
||||
|
||||
async def register_to_main_tokenizer_manager(self):
|
||||
"""Register this worker to the main TokenizerManager"""
|
||||
# create a handle loop to receive messages from the main TokenizerManager
|
||||
self.auto_create_handle_loop()
|
||||
req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
|
||||
req.ipc_name = self.tokenizer_ipc_name
|
||||
_Communicator.enable_multi_tokenizer = True
|
||||
await self.register_multi_tokenizer_communicator(req)
|
||||
def _attach_multi_http_worker_info(self, req: Union[BaseReq, BaseBatchReq]):
    """Stamp *req* with this worker's ``tokenizer_ipc_name``.

    A ``BaseReq`` gets the name in ``http_worker_ipc``; a ``BaseBatchReq``
    gets a list in ``http_worker_ipcs`` holding the name once per entry of
    ``req.rids``.

    Raises:
        ValueError: If *req* is neither a ``BaseReq`` nor a ``BaseBatchReq``.
    """
    if isinstance(req, BaseReq):
        req.http_worker_ipc = self.tokenizer_ipc_name
        return

    if isinstance(req, BaseBatchReq):
        own_ipc = self.tokenizer_ipc_name
        # Every rid in the batch originates from this same worker.
        req.http_worker_ipcs = [own_ipc] * len(req.rids)
        return

    raise ValueError(f"Unknown req type: {type(req)}")
|
||||
|
||||
|
||||
async def print_exception_wrapper(func):
|
||||
|
||||
@@ -438,6 +438,7 @@ class Req:
|
||||
priority: Optional[int] = None,
|
||||
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
||||
extra_key: Optional[str] = None,
|
||||
http_worker_ipc: Optional[str] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -461,6 +462,9 @@ class Req:
|
||||
# The length of KV that have been removed in local attention chunked prefill
|
||||
self.evicted_seqlen_local = 0
|
||||
|
||||
# For multi-http worker
|
||||
self.http_worker_ipc = http_worker_ipc
|
||||
|
||||
# Sampling info
|
||||
if isinstance(sampling_params.custom_params, dict):
|
||||
sampling_params = copy.copy(sampling_params)
|
||||
|
||||
@@ -24,7 +24,6 @@ from collections import deque
|
||||
from concurrent import futures
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from types import SimpleNamespace
|
||||
from typing import Deque, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
@@ -66,6 +65,8 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
from sglang.srt.layers.moe import initialize_moe_config
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BaseBatchReq,
|
||||
BaseReq,
|
||||
BatchTokenizedEmbeddingReqInput,
|
||||
BatchTokenizedGenerateReqInput,
|
||||
ClearHiCacheReqInput,
|
||||
@@ -89,8 +90,6 @@ from sglang.srt.managers.io_struct import (
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
MultiTokenizerRegisterReq,
|
||||
MultiTokenizerWrapper,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -277,47 +276,7 @@ class Scheduler(
|
||||
self.model_config = ModelConfig.from_server_args(server_args)
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
self.idle_sleeper = None
|
||||
if self.pp_rank == 0 and self.attn_tp_rank == 0:
|
||||
self.recv_from_tokenizer = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
||||
)
|
||||
self.recv_from_rpc = get_zmq_socket(
|
||||
context, zmq.DEALER, port_args.rpc_ipc_name, False
|
||||
)
|
||||
|
||||
self.send_to_tokenizer = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
||||
)
|
||||
if server_args.skip_tokenizer_init:
|
||||
# Directly send to the TokenizerManager
|
||||
self.send_to_detokenizer = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
||||
)
|
||||
else:
|
||||
# Send to the DetokenizerManager
|
||||
self.send_to_detokenizer = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
||||
)
|
||||
|
||||
if self.server_args.sleep_on_idle:
|
||||
self.idle_sleeper = IdleSleeper(
|
||||
[
|
||||
self.recv_from_tokenizer,
|
||||
self.recv_from_rpc,
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.recv_from_tokenizer = None
|
||||
self.recv_from_rpc = None
|
||||
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||
|
||||
if self.current_scheduler_metrics_enabled():
|
||||
self.send_metrics_from_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.metrics_ipc_name, False
|
||||
)
|
||||
self.init_sockets(server_args, port_args)
|
||||
|
||||
# Init tokenizer
|
||||
self.init_tokenizer()
|
||||
@@ -578,7 +537,6 @@ class Scheduler(
|
||||
(ExpertDistributionReq, self.expert_distribution_handle),
|
||||
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
||||
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
||||
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
|
||||
(GetLoadReqInput, self.get_load),
|
||||
]
|
||||
)
|
||||
@@ -634,6 +592,75 @@ class Scheduler(
|
||||
else:
|
||||
self.draft_worker = None
|
||||
|
||||
def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
    """Create the ZMQ sockets this scheduler rank uses for inter-process communication.

    The rank with ``pp_rank == 0`` and ``attn_tp_rank == 0`` opens the
    receive sockets and wraps its two PUSH sockets in ``SenderWrapper``.
    Every other rank gets ``None`` receivers and senders whose
    ``send_output`` is a no-op. A metrics PUSH socket is opened as well
    when ``current_scheduler_metrics_enabled()`` is true.

    Args:
        server_args: Read for ``skip_tokenizer_init``.
        port_args: Supplies the IPC endpoint names for every socket.
    """
    context = zmq.Context(2)
    self.idle_sleeper = None

    class SenderWrapper:
        """Thin wrapper around a PUSH socket (or ``None`` for a no-op sender)."""

        def __init__(self, socket: zmq.Socket):
            self.socket = socket

        def send_output(
            self,
            output: Union[BaseReq, BaseBatchReq],
            recv_obj: Optional[Union[BaseReq, BaseBatchReq, Req]] = None,
        ):
            """Send *output*, inheriting the originating HTTP worker from *recv_obj*.

            Args:
                output: The object to push through the socket.
                recv_obj: The request that produced *output*, if any. When
                    it carries a non-None ``http_worker_ipc`` and *output*
                    has that field still unset, the value is copied over.
            """
            if self.socket is None:
                return

            # Duck-typed on purpose: scheduler-side ``Req`` objects are not
            # ``BaseReq`` instances but carry ``http_worker_ipc`` too, and
            # the abort paths pass them as ``recv_obj``. An isinstance check
            # would leave those AbortReqs without a routing target.
            origin_ipc = getattr(recv_obj, "http_worker_ipc", None)
            if (
                origin_ipc is not None
                and hasattr(output, "http_worker_ipc")
                and output.http_worker_ipc is None
            ):
                # handle communicator reqs for multi-http worker case
                output.http_worker_ipc = origin_ipc

            self.socket.send_pyobj(output)

    if self.pp_rank == 0 and self.attn_tp_rank == 0:
        self.recv_from_tokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.scheduler_input_ipc_name, False
        )
        self.recv_from_rpc = get_zmq_socket(
            context, zmq.DEALER, port_args.rpc_ipc_name, False
        )

        send_to_tokenizer = get_zmq_socket(
            context, zmq.PUSH, port_args.tokenizer_ipc_name, False
        )
        if server_args.skip_tokenizer_init:
            # Directly send to the TokenizerManager
            send_to_detokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
        else:
            # Send to the DetokenizerManager
            send_to_detokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.detokenizer_ipc_name, False
            )

        self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
        self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)

        if self.server_args.sleep_on_idle:
            self.idle_sleeper = IdleSleeper(
                [
                    self.recv_from_tokenizer,
                    self.recv_from_rpc,
                ]
            )
    else:
        # Ranks without sockets still expose the same sender interface.
        self.recv_from_tokenizer = None
        self.recv_from_rpc = None
        self.send_to_tokenizer = SenderWrapper(None)
        self.send_to_detokenizer = SenderWrapper(None)

    if self.current_scheduler_metrics_enabled():
        self.send_metrics_from_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.metrics_ipc_name, False
        )
|
||||
|
||||
def init_deterministic_inference_config(self):
|
||||
"""Initialize deterministic inference configuration for different attention backends."""
|
||||
if not self.server_args.enable_deterministic_inference:
|
||||
@@ -1107,23 +1134,13 @@ class Scheduler(
|
||||
self.return_health_check_ct += 1
|
||||
continue
|
||||
|
||||
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
|
||||
if isinstance(recv_req, MultiTokenizerWrapper):
|
||||
worker_id = recv_req.worker_id
|
||||
recv_req = recv_req.obj
|
||||
output = self._request_dispatcher(recv_req)
|
||||
if output is not None:
|
||||
output = MultiTokenizerWrapper(worker_id, output)
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
continue
|
||||
|
||||
output = self._request_dispatcher(recv_req)
|
||||
if output is not None:
|
||||
if isinstance(output, RpcReqOutput):
|
||||
if self.recv_from_rpc is not None:
|
||||
self.recv_from_rpc.send_pyobj(output)
|
||||
else:
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
self.send_to_tokenizer.send_output(output, recv_req)
|
||||
|
||||
def init_req_max_new_tokens(self, req):
|
||||
req.sampling_params.max_new_tokens = min(
|
||||
@@ -1179,6 +1196,7 @@ class Scheduler(
|
||||
metrics_collector=(
|
||||
self.metrics_collector if self.enable_metrics else None
|
||||
),
|
||||
http_worker_ipc=recv_req.http_worker_ipc,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
@@ -1382,7 +1400,7 @@ class Scheduler(
|
||||
},
|
||||
rid=req.rid,
|
||||
)
|
||||
self.send_to_tokenizer.send_pyobj(abort_req)
|
||||
self.send_to_tokenizer.send_output(abort_req, req)
|
||||
|
||||
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
|
||||
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
|
||||
@@ -1414,7 +1432,7 @@ class Scheduler(
|
||||
req_to_abort = candidate_req
|
||||
message = "The request is aborted by a higher priority request."
|
||||
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
self.send_to_tokenizer.send_output(
|
||||
AbortReq(
|
||||
finished_reason={
|
||||
"type": "abort",
|
||||
@@ -1422,7 +1440,8 @@ class Scheduler(
|
||||
"message": message,
|
||||
},
|
||||
rid=req_to_abort.rid,
|
||||
)
|
||||
),
|
||||
req_to_abort,
|
||||
)
|
||||
return req_to_abort.rid == recv_req.rid
|
||||
|
||||
@@ -1437,6 +1456,7 @@ class Scheduler(
|
||||
recv_req.sampling_params,
|
||||
token_type_ids=recv_req.token_type_ids,
|
||||
priority=recv_req.priority,
|
||||
http_worker_ipc=recv_req.http_worker_ipc,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
@@ -1953,8 +1973,8 @@ class Scheduler(
|
||||
self.num_retracted_reqs = len(retracted_reqs)
|
||||
self.new_token_ratio = new_token_ratio
|
||||
for req in reqs_to_abort:
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
|
||||
self.send_to_tokenizer.send_output(
|
||||
AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -2138,7 +2158,7 @@ class Scheduler(
|
||||
# This is used to prevent the health check signal being blocked by long context prefill.
|
||||
# However, one minor issue is that this code path does not check the status of detokenizer manager.
|
||||
self.return_health_check_ct -= 1
|
||||
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
|
||||
self.send_to_tokenizer.send_output(HealthCheckOutput())
|
||||
|
||||
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
|
||||
return self.prepare_mlp_sync_batch_raw(
|
||||
@@ -2585,7 +2605,7 @@ class Scheduler(
|
||||
if self.enable_hicache_storage:
|
||||
# to release prefetch events associated with the request
|
||||
self.tree_cache.release_aborted_request(req.rid)
|
||||
self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
|
||||
self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
|
||||
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
@@ -2669,10 +2689,6 @@ class Scheduler(
|
||||
result = self.tp_worker.unload_lora_adapter(recv_req)
|
||||
return result
|
||||
|
||||
def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
    """Forward a registration request to the detokenizer and return it unchanged."""
    detokenizer_sender = self.send_to_detokenizer
    detokenizer_sender.send_pyobj(recv_req)
    return recv_req
|
||||
|
||||
def init_weights_send_group_for_remote_instance(
|
||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||
):
|
||||
@@ -2751,7 +2767,7 @@ class Scheduler(
|
||||
def handle_freeze_gc(self, recv_req: FreezeGCReq):
|
||||
"""Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
|
||||
freeze_gc("Scheduler")
|
||||
self.send_to_detokenizer.send_pyobj(recv_req)
|
||||
self.send_to_detokenizer.send_output(recv_req, recv_req)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -682,6 +682,7 @@ class SchedulerOutputProcessorMixin:
|
||||
skip_req: Optional[Req] = None,
|
||||
):
|
||||
rids = []
|
||||
http_worker_ipcs = []
|
||||
finished_reasons: List[BaseFinishReason] = []
|
||||
|
||||
decoded_texts = []
|
||||
@@ -770,6 +771,7 @@ class SchedulerOutputProcessorMixin:
|
||||
req.send_output_token_logprobs_offset
|
||||
)
|
||||
rids.append(req.rid)
|
||||
http_worker_ipcs.append(req.http_worker_ipc)
|
||||
finished_reasons.append(
|
||||
req.finished_reason.to_json() if req.finished_reason else None
|
||||
)
|
||||
@@ -886,7 +888,7 @@ class SchedulerOutputProcessorMixin:
|
||||
if self.model_config.is_multimodal_gen:
|
||||
return
|
||||
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
self.send_to_detokenizer.send_output(
|
||||
BatchTokenIDOutput(
|
||||
finished_reasons,
|
||||
decoded_texts,
|
||||
@@ -916,6 +918,7 @@ class SchedulerOutputProcessorMixin:
|
||||
output_token_entropy_val=None,
|
||||
output_hidden_states=output_hidden_states,
|
||||
rids=rids,
|
||||
http_worker_ipcs=http_worker_ipcs,
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
)
|
||||
@@ -923,6 +926,7 @@ class SchedulerOutputProcessorMixin:
|
||||
|
||||
def stream_output_embedding(self: Scheduler, reqs: List[Req]):
|
||||
rids = []
|
||||
http_worker_ipcs = []
|
||||
finished_reasons: List[BaseFinishReason] = []
|
||||
|
||||
embeddings = []
|
||||
@@ -931,17 +935,19 @@ class SchedulerOutputProcessorMixin:
|
||||
for req in reqs:
|
||||
if req.finished():
|
||||
rids.append(req.rid)
|
||||
http_worker_ipcs.append(req.http_worker_ipc)
|
||||
finished_reasons.append(req.finished_reason.to_json())
|
||||
embeddings.append(req.embedding)
|
||||
prompt_tokens.append(len(req.origin_input_ids))
|
||||
cached_tokens.append(req.cached_tokens)
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
self.send_to_detokenizer.send_output(
|
||||
BatchEmbeddingOutput(
|
||||
finished_reasons,
|
||||
embeddings,
|
||||
prompt_tokens,
|
||||
cached_tokens,
|
||||
rids=rids,
|
||||
http_worker_ipcs=http_worker_ipcs,
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
@@ -46,7 +45,6 @@ from sglang.srt.managers.io_struct import (
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
LoRAUpdateOutput,
|
||||
MultiTokenizerWrapper,
|
||||
OpenSessionReqInput,
|
||||
ProfileReq,
|
||||
ProfileReqOutput,
|
||||
@@ -83,8 +81,6 @@ logger = logging.getLogger(__name__)
|
||||
class _Communicator(Generic[T]):
|
||||
"""Note: The communicator now only run up to 1 in-flight request at any time."""
|
||||
|
||||
enable_multi_tokenizer = False
|
||||
|
||||
def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
|
||||
self._sender = sender
|
||||
self._fan_out = fan_out
|
||||
@@ -104,8 +100,6 @@ class _Communicator(Generic[T]):
|
||||
assert self._result_values is None
|
||||
|
||||
if obj:
|
||||
if _Communicator.enable_multi_tokenizer:
|
||||
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
|
||||
self._sender.send_pyobj(obj)
|
||||
|
||||
self._result_event = asyncio.Event()
|
||||
@@ -126,8 +120,6 @@ class _Communicator(Generic[T]):
|
||||
self._result_event = asyncio.Event()
|
||||
|
||||
if obj:
|
||||
if _Communicator.enable_multi_tokenizer:
|
||||
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
|
||||
self._sender.send_pyobj(obj)
|
||||
|
||||
await self._result_event.wait()
|
||||
@@ -617,8 +609,6 @@ class TokenizerCommunicatorMixin:
|
||||
elif obj.session_id in self.session_futures:
|
||||
return None
|
||||
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
obj = MultiTokenizerWrapper(self.worker_id, obj)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
self.session_futures[obj.session_id] = asyncio.Future()
|
||||
|
||||
@@ -46,6 +46,7 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BaseReq,
|
||||
BatchEmbeddingOutput,
|
||||
BatchMultimodalOutput,
|
||||
BatchStrOutput,
|
||||
@@ -58,7 +59,6 @@ from sglang.srt.managers.io_struct import (
|
||||
GenerateReqInput,
|
||||
GetLoadReqInput,
|
||||
HealthCheckOutput,
|
||||
MultiTokenizerWrapper,
|
||||
OpenSessionReqOutput,
|
||||
SessionParams,
|
||||
TokenizedEmbeddingReqInput,
|
||||
@@ -88,7 +88,6 @@ from sglang.srt.utils import (
|
||||
dataclass_to_string_truncated,
|
||||
freeze_gc,
|
||||
get_bool_env_var,
|
||||
get_origin_rid,
|
||||
get_zmq_socket,
|
||||
kill_process_tree,
|
||||
)
|
||||
@@ -258,9 +257,18 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
)
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
# Use tokenizer_worker_ipc_name in multi-tokenizer mode
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
|
||||
)
|
||||
|
||||
class SenderWrapper:
    """Stand-in for the scheduler socket that tags single requests before sending.

    Closes over ``port_args`` and ``send_to_scheduler`` from the enclosing
    scope.
    """

    def send_pyobj(self, obj):
        """Set ``http_worker_ipc`` on ``BaseReq`` objects, then forward *obj*."""
        is_single_request = isinstance(obj, BaseReq)
        if is_single_request:
            obj.http_worker_ipc = port_args.tokenizer_ipc_name
        send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
# Make sure that each request carries the tokenizer_ipc_name for response routing
|
||||
self.send_to_scheduler = SenderWrapper()
|
||||
else:
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||
@@ -376,13 +384,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
obj.normalize_batch_and_arguments()
|
||||
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
# Modify rid, add worker_id
|
||||
if isinstance(obj.rid, list):
|
||||
# If it's an array, add worker_id prefix to each element
|
||||
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
|
||||
else:
|
||||
# If it's a single value, add worker_id prefix
|
||||
obj.rid = f"{self.worker_id}_{obj.rid}"
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
|
||||
|
||||
assert isinstance(self, TokenizerWorker)
|
||||
self._attach_multi_http_worker_info(obj)
|
||||
|
||||
if self.enable_trace:
|
||||
self._trace_request_start(obj, created_time)
|
||||
@@ -728,6 +733,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
obj.token_ids_logprob,
|
||||
obj.stream,
|
||||
rid=obj.rid,
|
||||
http_worker_ipc=obj.http_worker_ipc,
|
||||
bootstrap_host=obj.bootstrap_host,
|
||||
bootstrap_port=obj.bootstrap_port,
|
||||
bootstrap_room=obj.bootstrap_room,
|
||||
@@ -749,6 +755,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
sampling_params,
|
||||
rid=obj.rid,
|
||||
priority=obj.priority,
|
||||
http_worker_ipc=obj.http_worker_ipc,
|
||||
)
|
||||
|
||||
return tokenized_obj
|
||||
@@ -1109,8 +1116,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
async def _wait_for_model_update_from_disk(
|
||||
self, obj: UpdateWeightFromDiskReqInput
|
||||
) -> Tuple[bool, str]:
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
obj = MultiTokenizerWrapper(self.worker_id, obj)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
if self.server_args.dp_size == 1:
|
||||
@@ -1349,12 +1354,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
)
|
||||
continue
|
||||
|
||||
origin_rid = rid
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
origin_rid = get_origin_rid(rid)
|
||||
# Build meta_info and return value
|
||||
meta_info = {
|
||||
"id": origin_rid,
|
||||
"id": rid,
|
||||
"finish_reason": recv_obj.finished_reasons[i],
|
||||
"prompt_tokens": recv_obj.prompt_tokens[i],
|
||||
"weight_version": self.server_args.weight_version,
|
||||
@@ -1708,9 +1710,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
if is_health_check_generate_req(recv_obj):
|
||||
return
|
||||
state = self.rid_to_state[recv_obj.rid]
|
||||
origin_rid = recv_obj.rid
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
origin_rid = get_origin_rid(origin_rid)
|
||||
state.finished = True
|
||||
if recv_obj.finished_reason:
|
||||
out = {
|
||||
@@ -1723,7 +1722,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
out = {
|
||||
"text": "",
|
||||
"meta_info": {
|
||||
"id": origin_rid,
|
||||
"id": recv_obj.rid,
|
||||
"finish_reason": {
|
||||
"type": "abort",
|
||||
"message": "Abort before prefill",
|
||||
|
||||
@@ -3006,10 +3006,6 @@ def lru_cache_frozenset(maxsize=128):
|
||||
return decorator
|
||||
|
||||
|
||||
def get_origin_rid(rid):
    """Return *rid* with everything up to and including its first underscore removed.

    A *rid* that contains no underscore is returned unchanged.
    """
    _prefix, separator, remainder = rid.partition("_")
    return remainder if separator else rid
|
||||
|
||||
|
||||
def apply_module_patch(target_module, target_function, wrappers):
|
||||
original_module, original_function = parse_module_path(
|
||||
target_module, target_function, False
|
||||
|
||||
Reference in New Issue
Block a user