Files
sglang/python/sglang/srt/managers/multi_tokenizer_mixin.py
2025-09-05 13:39:46 +08:00

634 lines
24 KiB
Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
import asyncio
import dataclasses
import json
import logging
import multiprocessing as multiprocessing
import os
import sys
import threading
from multiprocessing import shared_memory
from typing import Dict
import setproctitle
import zmq
import zmq.asyncio
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
MultiTokenizerRegisterReq,
MultiTokenizerWrapper,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback
# Module-level logger named after this module; used by every class and helper below.
logger = logging.getLogger(__name__)
class MultiTokenizerMixin:
    """Mixin class for MultiTokenizerManager and DetokenizerManager.

    Maintains ``self.tokenizer_mapping`` (worker id -> ZMQ PUSH socket) and
    splits batched outputs into single-request outputs, each of which is pushed
    to the tokenizer worker whose id is the prefix of the request id
    (``"<worker_id>_..."``).

    ``multi_tokenizer_manager_event_loop`` additionally expects the host class
    to provide ``self.recv_from_scheduler`` and ``self._request_dispatcher``.
    """

    # Per-request list fields that BatchTokenIDOut and BatchStrOut have in common.
    _LOGPROB_AND_HIDDEN_FIELDS = (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "output_token_logprobs_val",
        "output_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "output_top_logprobs_val",
        "output_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
        "output_token_ids_logprobs_val",
        "output_token_ids_logprobs_idx",
        "output_hidden_states",
    )

    # Fields (besides ``rids``) that are sliced for each supported output type.
    # Fields not listed here are left to the dataclass defaults.
    _TOKEN_ID_OUT_FIELDS = (
        "finished_reasons",
        "decoded_texts",
        "decode_ids",
        "read_offsets",
        "output_ids",
        "skip_special_tokens",
        "spaces_between_special_tokens",
        "no_stop_trim",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
        "spec_verify_ct",
    ) + _LOGPROB_AND_HIDDEN_FIELDS

    _EMBEDDING_OUT_FIELDS = (
        "finished_reasons",
        "embeddings",
        "prompt_tokens",
        "cached_tokens",
    )

    _STR_OUT_FIELDS = (
        "finished_reasons",
        "output_strs",
        "output_ids",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
        "spec_verify_ct",
    ) + _LOGPROB_AND_HIDDEN_FIELDS

    _MULTIMODAL_OUT_FIELDS = (
        "finished_reasons",
        "outputs",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
    )

    def create_sockets_mapping(self):
        """Lazily create ``tokenizer_mapping`` and the ZMQ context.

        Safe to call repeatedly: existing attributes are left untouched.
        """
        if not hasattr(self, "tokenizer_mapping"):
            self.tokenizer_mapping = {}
        # Create ZMQ context if needed
        if not hasattr(self, "_zmq_context"):
            self._zmq_context = zmq.Context()

    def init_tokenizer_mapping(
        self, recv_obj: "MultiTokenizerRegisterReq", worker_id: str
    ):
        """Create the PUSH socket for a worker from its register request.

        Args:
            recv_obj: Register request; its ``ipc_name`` is the socket endpoint.
            worker_id: Worker id; it is converted with ``int()`` before being
                used as the mapping key.

        Returns:
            True if a new socket was created (the register request is then sent
            back over it), False if the worker was already in the mapping.
        """
        ipc_name = recv_obj.ipc_name
        worker_id_int = int(worker_id)
        if worker_id_int in self.tokenizer_mapping:
            return False

        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
        self.tokenizer_mapping[worker_id_int] = socket
        # Send the register request back over the freshly created socket.
        socket.send_pyobj(recv_obj)
        return True

    def register_tokenizer_ipc(self, recv_obj, worker_id):
        """Register ``worker_id`` if it is not yet known.

        Returns:
            True only when a new socket was created for the worker. False when
            the worker is already registered, or when it is unknown and
            ``recv_obj`` is not a register request (an error is logged).
        """
        if worker_id in self.tokenizer_mapping:
            # Already registered: nothing to do. Return an explicit False
            # rather than falling through to an implicit None.
            return False

        # register the worker if not already done
        if isinstance(recv_obj, MultiTokenizerRegisterReq):
            return self.init_tokenizer_mapping(recv_obj, worker_id)

        logger.error(
            f"Worker {worker_id} not registered and not found in tokenizer mapping . "
            "Please ensure the worker is registered correctly."
        )
        return False

    @staticmethod
    def _select_index(values, i):
        """Return ``[values[i]]``, or None when ``values`` is None or too short."""
        if values is None or len(values) <= i:
            return None
        return [values[i]]

    def _handle_output_by_index(self, output, i):
        """Build a single-request copy of a batched output.

        For the four supported batch types a new object of the same class is
        created whose list fields contain only element ``i``. A field that is
        None, empty, or shorter than ``i + 1`` becomes None in the result
        instead of raising. Any other object is returned unchanged.

        Args:
            output: A batched output object, or any other object.
            i: Index of the request inside the batch.

        Raises:
            IndexError: If ``output.rids`` has no element ``i``.
        """
        if isinstance(output, BatchTokenIDOut):
            out_cls, field_names = BatchTokenIDOut, self._TOKEN_ID_OUT_FIELDS
        elif isinstance(output, BatchEmbeddingOut):
            out_cls, field_names = BatchEmbeddingOut, self._EMBEDDING_OUT_FIELDS
        elif isinstance(output, BatchStrOut):
            out_cls, field_names = BatchStrOut, self._STR_OUT_FIELDS
        elif isinstance(output, BatchMultimodalOut):
            out_cls, field_names = BatchMultimodalOut, self._MULTIMODAL_OUT_FIELDS
        else:
            return output

        # ``rids`` is mandatory, so it is indexed without a guard.
        rids = [output.rids[i]]
        sliced = {
            name: self._select_index(getattr(output, name), i)
            for name in field_names
        }
        return out_cls(rids=rids, **sliced)

    def get_worker_ids_from_req_rids(self, rids):
        """Extract worker ids from request ids of the form ``"<worker_id>_..."``.

        Args:
            rids: A list of request ids, or a single request id string.

        Returns:
            A list of int worker ids; empty for any other input type.
        """
        if isinstance(rids, list):
            worker_ids = [int(rid.split("_")[0]) for rid in rids]
        elif isinstance(rids, str):
            worker_ids = [int(rids.split("_")[0])]
        else:
            worker_ids = []
        return worker_ids

    def multi_tokenizer_manager_event_loop(self):
        """The event loop that handles requests, for multi tokenizer manager mode only"""
        self.create_sockets_mapping()
        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            output = self._request_dispatcher(recv_obj)
            if output is None:
                continue

            # Extract worker_id from rid
            if not isinstance(recv_obj.rids, list):
                raise RuntimeError(
                    "for tokenizer_worker_num > 1, recv_obj.rids must be a list"
                )
            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)

            # Send data using the corresponding socket
            for i, worker_id in enumerate(worker_ids):
                if isinstance(recv_obj, MultiTokenizerRegisterReq):
                    if self.register_tokenizer_ipc(recv_obj, worker_id):
                        logger.info(
                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
                        )
                    continue

                if worker_id not in self.tokenizer_mapping:
                    logger.error(
                        f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
                    )
                    continue

                new_output = self._handle_output_by_index(output, i)
                self.tokenizer_mapping[worker_id].send_pyobj(new_output)

    def clear_tokenizer_mapping(self):
        """Close every registered socket (best effort) and empty the mapping."""
        if hasattr(self, "tokenizer_mapping"):
            for socket in self.tokenizer_mapping.values():
                try:
                    socket.close()
                except Exception as e:
                    # Best effort: keep closing the remaining sockets.
                    logger.warning(f"Failed to close socket: {e}")
            self.tokenizer_mapping.clear()
class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
    """A router to receive requests from MultiTokenizerManager.

    Two coroutines run on a private event loop in a daemon thread:
    ``router_worker_obj`` forwards objects from the worker socket to the
    scheduler, and ``handle_loop`` fans objects from the detokenizer socket
    back out to the registered workers.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args

        zmq_ctx = zmq.asyncio.Context(3)
        self.recv_from_detokenizer = get_zmq_socket(
            zmq_ctx, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            zmq_ctx, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )
        self.receive_from_worker = get_zmq_socket(
            zmq_ctx, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
        )

        # Dedicated event loop driven by a daemon thread.
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

        # Worker -> scheduler forwarding.
        forward_coro = self.router_worker_obj()
        self._task = asyncio.run_coroutine_threadsafe(forward_coro, self._loop)

        # Start handle_loop simultaneously
        handle_coro = print_exception_wrapper(self.handle_loop)
        self._handle_task = asyncio.run_coroutine_threadsafe(handle_coro, self._loop)

        self.init_disaggregation()

    def _run_loop(self):
        """Thread target: run the private event loop until it is stopped."""
        self._loop.run_forever()

    async def router_worker_obj(self):
        """Forward every object received from a worker to the scheduler."""
        while True:
            incoming = await self.receive_from_worker.recv_pyobj()
            await self.send_to_scheduler.send_pyobj(incoming)

    async def handle_loop(self):
        """Receive objects from the detokenizer socket and route them to workers."""
        # special reqs will recv from scheduler, need to route to right worker
        self.create_sockets_mapping()
        while True:
            incoming = await self.recv_from_detokenizer.recv_pyobj()
            await self._distribute_result_to_workers(incoming)

    async def _distribute_result_to_workers(self, recv_obj):
        """Distribute result to corresponding workers based on rid"""
        if isinstance(recv_obj, MultiTokenizerWrapper):
            # A wrapper names its target worker explicitly; unwrap the payload.
            worker_ids = [recv_obj.worker_id]
            recv_obj = recv_obj.obj
        else:
            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)

        if not worker_ids:
            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
            return

        is_register_req = isinstance(recv_obj, MultiTokenizerRegisterReq)

        # Distribute result to each worker
        for position, worker_id in enumerate(worker_ids):
            if is_register_req:
                created = self.register_tokenizer_ipc(recv_obj, worker_id)
                if created:
                    logger.info(
                        f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
                    )
                continue

            if worker_id not in self.tokenizer_mapping:
                logger.error(
                    f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
                )
                continue

            single_result = self._handle_output_by_index(recv_obj, position)
            self.tokenizer_mapping[worker_id].send_pyobj(single_result)
class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
    """Multi Process Tokenizer Manager that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        proc_title = f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
        setproctitle.setproctitle(proc_title)

        # Run the base initializer with disaggregation_mode set to "null" so the
        # prefill bootstrap server is not initialized again; the real value is
        # restored right afterwards.
        saved_mode = server_args.disaggregation_mode
        server_args.disaggregation_mode = "null"
        super().__init__(server_args, port_args)

        self.worker_id = os.getpid()
        self.tokenizer_ipc_name = port_args.tokenizer_ipc_name

        # PD disaggregation settings, derived from the restored mode.
        self.server_args.disaggregation_mode = saved_mode
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        # Communicator used for the registration handshake; its receive handler
        # is appended to the result dispatcher for register requests.
        communicator = _Communicator(self.send_to_scheduler, 2)
        self.register_multi_tokenizer_communicator = communicator
        dispatch_entry = (MultiTokenizerRegisterReq, communicator.handle_recv)
        self._result_dispatcher._mapping.append(dispatch_entry)

    async def register_to_main_tokenizer_manager(self):
        """Register this worker to the main TokenizerManager"""
        # create a handle loop to receive messages from the main TokenizerManager
        self.auto_create_handle_loop()

        register_req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
        register_req.ipc_name = self.tokenizer_ipc_name

        _Communicator.enable_multi_tokenizer = True
        await self.register_multi_tokenizer_communicator(register_req)
async def print_exception_wrapper(func):
    """
    Await ``func()`` and make sure any exception it raises is logged.

    Sometimes an asyncio function does not print its exception, so this wrapper
    logs the traceback itself. On failure it also dumps pending requests when
    ``func`` is a method bound to a MultiTokenizerRouter, then kills the
    process tree and exits with status 1.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"MultiTokenizerRouter hit an exception: {traceback}")

        # Bound methods expose their instance as ``__self__``.
        owner = getattr(func, "__self__", None)
        if isinstance(owner, MultiTokenizerRouter):
            owner.dump_requests_before_crash()

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
def serialize_port_args(port_args: PortArgs) -> dict:
    """Copy the shareable fields of ``port_args`` into a plain dictionary.

    The keys are the attribute names, in the order listed below.
    """
    exported_fields = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
        "nccl_port",
        "rpc_ipc_name",
        "metrics_ipc_name",
        "tokenizer_worker_ipc_name",
    )
    return {field: getattr(port_args, field) for field in exported_fields}
def deserialize_data(port_args: dict, server_args: dict):
    """Rebuild ``PortArgs`` and ``ServerArgs`` from their dictionary forms.

    Each dictionary is expanded into keyword arguments of the matching
    constructor.

    Returns:
        A ``(PortArgs, ServerArgs)`` tuple.
    """
    restored_port_args = PortArgs(**port_args)
    restored_server_args = ServerArgs(**server_args)
    return restored_port_args, restored_server_args
def serialize_server_args(server_args: ServerArgs) -> dict:
    """Convert ``server_args`` to a dictionary via ``dataclasses.asdict``."""
    as_dict = dataclasses.asdict(server_args)
    return as_dict
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
    """Return ``scheduler_info`` as the dictionary to share.

    The input is returned as-is: the same object, with no copy or conversion.
    """
    return scheduler_info
def deserialize_scheduler_info(data: dict) -> Dict:
    """Return the scheduler info stored in ``data``.

    The input is returned as-is: the same object, with no copy or conversion.
    """
    return data
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
    """Write ``data`` as UTF-8 JSON into the shared memory segment ``name``.

    An existing segment is reused when it is large enough; otherwise it is
    unlinked and recreated. A missing segment is created.

    Args:
        data: JSON-serializable dictionary.
        name: Name of the shared memory segment.

    Returns:
        The open ``SharedMemory`` handle. The caller is responsible for
        closing it and eventually unlinking the segment.
    """
    serialized = json.dumps(data).encode("utf-8")
    size = len(serialized)
    try:
        # Try to open existing shared memory
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        # If not present, create new shared memory
        shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    else:
        # If size is insufficient, close and recreate
        if shm.size < size:
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=size, name=name)

    shm.buf[:size] = serialized

    # A reused (or platform-padded) segment can be larger than the payload.
    # Zero the tail so bytes from a previous, longer payload do not remain
    # after the new JSON document.
    capacity = len(shm.buf)
    if capacity > size:
        shm.buf[size:capacity] = bytes(capacity - size)
    return shm
def read_from_shared_memory(name: str) -> dict:
    """Read the JSON document stored in the shared memory segment ``name``.

    Args:
        name: Name of the shared memory segment.

    Returns:
        The decoded JSON content.

    Raises:
        FileNotFoundError: If no segment called ``name`` exists.
    """
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Shared memory {name} not found") from err

    try:
        # The segment may be larger than the payload (reused segment or
        # platform padding). Strip trailing NUL bytes, which would otherwise
        # make json.loads fail with "Extra data".
        payload = bytes(shm.buf).rstrip(b"\x00")
    finally:
        # Always release the local mapping, even if the copy above fails.
        shm.close()

    return json.loads(payload.decode("utf-8"))
def get_main_process_id() -> int:
    """Return the parent PID recorded on the current multiprocessing process.

    This reads the private ``_parent_pid`` attribute of
    ``multiprocessing.current_process()``.
    NOTE(review): presumably None when the current process was not started by
    multiprocessing - confirm.
    """
    this_process = multiprocessing.current_process()
    return this_process._parent_pid
def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Publish the args and scheduler info in shared memory for multi-tokenizer.

    Three segments are written, named ``port_args_<pid>``,
    ``server_args_<pid>`` and ``scheduler_info_<pid>``, where ``<pid>`` is the
    current process id.

    Returns:
        A tuple ``(port_args_shm, server_args_shm, scheduler_info_shm)`` of
        handles. They have already been closed in this process, but the
        segments have not been unlinked.
    """
    # get main process ID
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    # (serializer, source object, segment-name prefix), written in this order.
    plan = (
        (serialize_port_args, port_args, "port_args"),
        (serialize_server_args, server_args, "server_args"),
        (serialize_scheduler_info, scheduler_info, "scheduler_info"),
    )

    segments = []
    for serializer, source, prefix in plan:
        segment = write_to_shared_memory(
            serializer(source), f"{prefix}_{current_pid}"
        )
        segments.append(segment)

    # Release the local mappings; the named segments stay available to readers.
    for segment in segments:
        segment.close()

    port_args_shm, server_args_shm, scheduler_info_shm = segments
    return port_args_shm, server_args_shm, scheduler_info_shm