[1/2] Refactor multi-tokenizer manager (#10074)

This commit is contained in:
Liangsheng Yin
2025-09-07 19:13:34 +08:00
committed by GitHub
parent 067246830d
commit e719bb0e84
6 changed files with 421 additions and 485 deletions

View File

@@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
def _init_tokenizer_manager(
    server_args: ServerArgs, port_args: PortArgs
) -> Tuple[TokenizerManager, TemplateManager]:
    """Create a TokenizerManager and initialize its templates.

    Args:
        server_args: Global server configuration (model path, template names).
        port_args: IPC name / port configuration for inter-process sockets.

    Returns:
        A ``(tokenizer_manager, template_manager)`` tuple; the template
        manager has already been initialized against the tokenizer manager.
    """
    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)

    # Initialize chat/completion templates against the freshly created manager.
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )
    # NOTE: the original annotation said ``-> TokenizerManager`` but the
    # function has always returned this 2-tuple; callers unpack both values.
    return tokenizer_manager, template_manager
def _launch_subprocesses( def _launch_subprocesses(
server_args: ServerArgs, port_args: Optional[PortArgs] = None server_args: ServerArgs, port_args: Optional[PortArgs] = None
) -> Tuple[TokenizerManager, TemplateManager, Dict]: ) -> Tuple[TokenizerManager, TemplateManager, Dict]:
@@ -816,23 +834,15 @@ def _launch_subprocesses(
), ),
) )
detoken_proc.start() detoken_proc.start()
# Init tokenizer manager first, as the bootstrap server is initialized here
if server_args.tokenizer_worker_num > 1: if server_args.tokenizer_worker_num > 1:
# Launch multi-tokenizer router # Launch multi-tokenizer router
tokenizer_manager = MultiTokenizerRouter(server_args, port_args) tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
# Initialize templates
template_manager = None template_manager = None
else: else:
# Launch tokenizer process tokenizer_manager, template_manager = _init_tokenizer_manager(
tokenizer_manager = TokenizerManager(server_args, port_args) server_args, port_args
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
) )
# Wait for the model to finish loading # Wait for the model to finish loading
@@ -856,5 +866,7 @@ def _launch_subprocesses(
# Assume all schedulers have the same scheduler_info # Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0] scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, template_manager, scheduler_info return tokenizer_manager, template_manager, scheduler_info

View File

@@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.multi_tokenizer_mixin import ( from sglang.srt.managers.multi_tokenizer_mixin import (
MultiTokenizerManager, MultiTokenizerManager,
deserialize_data,
get_main_process_id, get_main_process_id,
read_from_shared_memory, read_from_shared_memory,
write_data_for_multi_tokenizer, write_data_for_multi_tokenizer,
@@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state _global_state = global_state
# Function to set up all middlewares for multi-tokenizer compatibility
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
"""Setup all middlewares for both single and multi-process modes"""
worker_pid = os.getpid()
if api_key:
add_api_key_middleware(app, api_key)
logger.info(f"Worker {worker_pid} added API key middleware")
if enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
logger.info(f"Worker {worker_pid} added prometheus middleware")
async def init_multi_tokenizer() -> ServerArgs: async def init_multi_tokenizer() -> ServerArgs:
"""Read args information from shm and init tokenizer manager for current process""" """Read args information from shm and init tokenizer manager for current process"""
pid = os.getpid() pid = os.getpid()
@@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs:
logger.info(f"current worker_id: {pid}, main processID: {main_pid}") logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
# Read configuration from shared memory # Read configuration from shared memory
port_args_data = read_from_shared_memory(f"port_args_{main_pid}") port_args, server_args, scheduler_info = read_from_shared_memory(
server_args_data = read_from_shared_memory(f"server_args_{main_pid}") f"multi_tokenizer_args_{main_pid}"
scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}") )
port_args, server_args = deserialize_data(port_args_data, server_args_data) server_args: ServerArgs
scheduler_info = scheduler_info_data
# API key authentication is not supported in multi-tokenizer mode
assert (
server_args.api_key is None
), "API key is not supported in multi-tokenizer mode"
port_args.tokenizer_ipc_name = ( port_args.tokenizer_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
@asynccontextmanager @asynccontextmanager
async def lifespan(fast_api_app: FastAPI): async def lifespan(fast_api_app: FastAPI):
server_args = getattr(fast_api_app, "server_args", None) if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
if server_args is None:
# Initialize multi-tokenizer support for worker processes # Initialize multi-tokenizer support for worker processes
fast_api_app.server_args = await init_multi_tokenizer() fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
setup_middlewares(
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics # only metrics middleware is supported in multi-tokenizer mode
) worker_pid = os.getpid()
if fast_api_app.server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
logger.info(f"Worker {worker_pid} added prometheus middleware")
fast_api_app.warmup_thread = threading.Thread( fast_api_app.warmup_thread = threading.Thread(
target=_wait_and_warmup, target=_wait_and_warmup,
args=( args=(
@@ -1187,12 +1179,10 @@ def launch_server(
) )
if server_args.tokenizer_worker_num > 1: if server_args.tokenizer_worker_num > 1:
port_args_shm, server_args_shm, scheduler_info_shm = ( multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
write_data_for_multi_tokenizer( port_args,
port_args, server_args,
server_args, scheduler_info,
scheduler_info,
)
) )
else: else:
# Add api key authorization # Add api key authorization
@@ -1239,6 +1229,7 @@ def launch_server(
workers=server_args.tokenizer_worker_num, workers=server_args.tokenizer_worker_num,
) )
else: else:
app.is_single_tokenizer_mode = True
uvicorn.run( uvicorn.run(
app, app,
host=server_args.host, host=server_args.host,
@@ -1249,10 +1240,8 @@ def launch_server(
) )
finally: finally:
if server_args.tokenizer_worker_num > 1: if server_args.tokenizer_worker_num > 1:
port_args_shm.unlink() multi_tokenizer_args_shm.unlink()
server_args_shm.unlink() _global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
scheduler_info_shm.unlink()
_global_state.tokenizer_manager.clear_tokenizer_mapping()
else: else:
warmup_thread.join() warmup_thread.join()

View File

@@ -34,7 +34,7 @@ from sglang.srt.managers.io_struct import (
FreezeGCReq, FreezeGCReq,
MultiTokenizerRegisterReq, MultiTokenizerRegisterReq,
) )
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
configure_logger, configure_logger,
@@ -69,7 +69,7 @@ class DecodeStatus:
sent_offset: int = 0 sent_offset: int = 0
class DetokenizerManager(MultiTokenizerMixin): class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
"""DetokenizerManager is a process that detokenizes the token ids.""" """DetokenizerManager is a process that detokenizes the token ids."""
def __init__( def __init__(
@@ -289,11 +289,11 @@ def run_detokenizer_process(
try: try:
manager = DetokenizerManager(server_args, port_args) manager = DetokenizerManager(server_args, port_args)
if server_args.tokenizer_worker_num > 1: if server_args.tokenizer_worker_num > 1:
manager.multi_tokenizer_manager_event_loop() manager.multi_http_worker_event_loop()
else: else:
manager.event_loop() manager.event_loop()
except Exception: except Exception:
manager.clear_tokenizer_mapping() manager.socket_mapping.clear_all_sockets()
traceback = get_exception_traceback() traceback = get_exception_traceback()
logger.error(f"DetokenizerManager hit an exception: {traceback}") logger.error(f"DetokenizerManager hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT) parent_process.send_signal(signal.SIGQUIT)

View File

@@ -0,0 +1,46 @@
"""Start bootstrap/kv-store-related server"""
import os
from typing import Type
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
TransferBackend,
get_kv_class,
)
from sglang.srt.server_args import ServerArgs
def start_disagg_service(
    server_args: ServerArgs,
):
    """Start the KV bootstrap server needed for disaggregated serving.

    Only the PREFILL side hosts a bootstrap server; for any other
    disaggregation mode this function returns ``None``.

    Args:
        server_args: Global server configuration; supplies the disaggregation
            mode, transfer backend, host, bootstrap port, and node rank.

    Returns:
        The created ``BaseKVBootstrapServer`` on prefill nodes, else ``None``.

    Raises:
        RuntimeError: If the Ascend mf config store cannot be created.
    """
    # Start kv bootstrap server on prefill
    disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
    transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)

    if disagg_mode == DisaggregationMode.PREFILL:
        # only start bootstrap server on prefill tm
        kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
            transfer_backend, KVClassType.BOOTSTRAP_SERVER
        )
        bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
            host=server_args.host,
            port=server_args.disaggregation_bootstrap_port,
        )
        # The Ascend backend additionally needs a shared config store,
        # created once on the rank-0 node.
        is_create_store = (
            server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
        )
        if is_create_store:
            try:
                from mf_adapter import create_config_store

                ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                create_config_store(ascend_url)
            except Exception as e:
                # BUG FIX: the original did ``raise error_message`` with a
                # plain str, which itself raises TypeError in Python 3.
                # Raise a real exception and chain the original cause.
                error_message = f"Failed create mf store, invalid ascend_url."
                error_message += f" With exception {e}"
                raise RuntimeError(error_message) from e
        return bootstrap_server

View File

@@ -13,21 +13,21 @@
# ============================================================================== # ==============================================================================
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager.""" """MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
import asyncio import asyncio
import dataclasses
import json
import logging import logging
import multiprocessing as multiprocessing import multiprocessing as multiprocessing
import os import os
import pickle
import sys import sys
import threading import threading
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Dict from typing import Any, Dict
import setproctitle import setproctitle
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
BatchEmbeddingOut, BatchEmbeddingOut,
BatchMultimodalOut, BatchMultimodalOut,
@@ -44,302 +44,296 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MultiTokenizerMixin: class SocketMapping:
"""Mixin class for MultiTokenizerManager and DetokenizerManager""" def __init__(self):
self._zmq_context = zmq.Context()
self._mapping: Dict[str, zmq.Socket] = {}
def create_sockets_mapping(self): def clear_all_sockets(self):
if not hasattr(self, "tokenizer_mapping"): for socket in self._mapping.values():
self.tokenizer_mapping = {} socket.close()
# Create ZMQ context if needed self._mapping.clear()
if not hasattr(self, "_zmq_context"):
self._zmq_context = zmq.Context()
def init_tokenizer_mapping( def register_ipc_mapping(
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
): ):
"""init tokenizer mapping from register request""" type_str = "tokenizer" if is_tokenizer else "detokenizer"
ipc_name = recv_obj.ipc_name if worker_id in self._mapping:
worker_id_int = int(worker_id) logger.warning(
f"{type_str} already registered with worker {worker_id}, skipping..."
)
return
logger.info(
f"{type_str} not registered with worker {worker_id}, registering..."
)
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
self._mapping[worker_id] = socket
self._mapping[worker_id].send_pyobj(recv_obj)
if worker_id_int not in self.tokenizer_mapping: def send_output(self, worker_id: str, output: Any):
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False) if worker_id not in self._mapping:
self.tokenizer_mapping[worker_id_int] = socket logger.error(
self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj) f"worker ID {worker_id} not registered. Check if the server Process is alive"
return True )
else: return
return False self._mapping[worker_id].send_pyobj(output)
def register_tokenizer_ipc(self, recv_obj, worker_id):
if worker_id not in self.tokenizer_mapping:
# register the worker if not already done
if isinstance(recv_obj, MultiTokenizerRegisterReq):
return self.init_tokenizer_mapping(recv_obj, worker_id)
else:
logger.error(
f"Worker {worker_id} not registered and not found in tokenizer mapping . "
"Please ensure the worker is registered correctly."
)
return False
def _handle_output_by_index(self, output, i): def _handle_output_by_index(output, i):
"""NOTE: A maintainable method is better here.""" """NOTE: A maintainable method is better here."""
if isinstance(output, BatchTokenIDOut): if isinstance(output, BatchTokenIDOut):
new_output = BatchTokenIDOut( new_output = BatchTokenIDOut(
rids=[output.rids[i]], rids=[output.rids[i]],
finished_reasons=( finished_reasons=(
[output.finished_reasons[i]] [output.finished_reasons[i]]
if len(output.finished_reasons) > i if len(output.finished_reasons) > i
else None else None
), ),
decoded_texts=( decoded_texts=(
[output.decoded_texts[i]] if len(output.decoded_texts) > i else None [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
), ),
decode_ids=( decode_ids=([output.decode_ids[i]] if len(output.decode_ids) > i else None),
[output.decode_ids[i]] if len(output.decode_ids) > i else None read_offsets=(
), [output.read_offsets[i]] if len(output.read_offsets) > i else None
read_offsets=( ),
[output.read_offsets[i]] if len(output.read_offsets) > i else None output_ids=(
), [output.output_ids[i]]
output_ids=( if output.output_ids and len(output.output_ids) > i
[output.output_ids[i]] else None
if output.output_ids and len(output.output_ids) > i ),
else None skip_special_tokens=(
), [output.skip_special_tokens[i]]
skip_special_tokens=( if len(output.skip_special_tokens) > i
[output.skip_special_tokens[i]] else None
if len(output.skip_special_tokens) > i ),
else None spaces_between_special_tokens=(
), [output.spaces_between_special_tokens[i]]
spaces_between_special_tokens=( if len(output.spaces_between_special_tokens) > i
[output.spaces_between_special_tokens[i]] else None
if len(output.spaces_between_special_tokens) > i ),
else None no_stop_trim=(
), [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
no_stop_trim=( ),
[output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None prompt_tokens=(
), [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
prompt_tokens=( ),
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None completion_tokens=(
), [output.completion_tokens[i]]
completion_tokens=( if len(output.completion_tokens) > i
[output.completion_tokens[i]] else None
if len(output.completion_tokens) > i ),
else None cached_tokens=(
), [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
cached_tokens=( ),
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None spec_verify_ct=(
), [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
spec_verify_ct=( ),
[output.spec_verify_ct[i]] input_token_logprobs_val=(
if len(output.spec_verify_ct) > i [output.input_token_logprobs_val[i]]
else None if output.input_token_logprobs_val
), else None
input_token_logprobs_val=( ),
[output.input_token_logprobs_val[i]] input_token_logprobs_idx=(
if output.input_token_logprobs_val [output.input_token_logprobs_idx[i]]
else None if output.input_token_logprobs_idx
), else None
input_token_logprobs_idx=( ),
[output.input_token_logprobs_idx[i]] output_token_logprobs_val=(
if output.input_token_logprobs_idx [output.output_token_logprobs_val[i]]
else None if output.output_token_logprobs_val
), else None
output_token_logprobs_val=( ),
[output.output_token_logprobs_val[i]] output_token_logprobs_idx=(
if output.output_token_logprobs_val [output.output_token_logprobs_idx[i]]
else None if output.output_token_logprobs_idx
), else None
output_token_logprobs_idx=( ),
[output.output_token_logprobs_idx[i]] input_top_logprobs_val=(
if output.output_token_logprobs_idx [output.input_top_logprobs_val[i]]
else None if output.input_top_logprobs_val
), else None
input_top_logprobs_val=( ),
[output.input_top_logprobs_val[i]] input_top_logprobs_idx=(
if output.input_top_logprobs_val [output.input_top_logprobs_idx[i]]
else None if output.input_top_logprobs_idx
), else None
input_top_logprobs_idx=( ),
[output.input_top_logprobs_idx[i]] output_top_logprobs_val=(
if output.input_top_logprobs_idx [output.output_top_logprobs_val[i]]
else None if output.output_top_logprobs_val
), else None
output_top_logprobs_val=( ),
[output.output_top_logprobs_val[i]] output_top_logprobs_idx=(
if output.output_top_logprobs_val [output.output_top_logprobs_idx[i]]
else None if output.output_top_logprobs_idx
), else None
output_top_logprobs_idx=( ),
[output.output_top_logprobs_idx[i]] input_token_ids_logprobs_val=(
if output.output_top_logprobs_idx [output.input_token_ids_logprobs_val[i]]
else None if output.input_token_ids_logprobs_val
), else None
input_token_ids_logprobs_val=( ),
[output.input_token_ids_logprobs_val[i]] input_token_ids_logprobs_idx=(
if output.input_token_ids_logprobs_val [output.input_token_ids_logprobs_idx[i]]
else None if output.input_token_ids_logprobs_idx
), else None
input_token_ids_logprobs_idx=( ),
[output.input_token_ids_logprobs_idx[i]] output_token_ids_logprobs_val=(
if output.input_token_ids_logprobs_idx [output.output_token_ids_logprobs_val[i]]
else None if output.output_token_ids_logprobs_val
), else None
output_token_ids_logprobs_val=( ),
[output.output_token_ids_logprobs_val[i]] output_token_ids_logprobs_idx=(
if output.output_token_ids_logprobs_val [output.output_token_ids_logprobs_idx[i]]
else None if output.output_token_ids_logprobs_idx
), else None
output_token_ids_logprobs_idx=( ),
[output.output_token_ids_logprobs_idx[i]] output_hidden_states=(
if output.output_token_ids_logprobs_idx [output.output_hidden_states[i]]
else None if output.output_hidden_states
), else None
output_hidden_states=( ),
[output.output_hidden_states[i]] )
if output.output_hidden_states elif isinstance(output, BatchEmbeddingOut):
else None new_output = BatchEmbeddingOut(
), rids=[output.rids[i]],
) finished_reasons=(
elif isinstance(output, BatchEmbeddingOut): [output.finished_reasons[i]]
new_output = BatchEmbeddingOut( if len(output.finished_reasons) > i
rids=[output.rids[i]], else None
finished_reasons=( ),
[output.finished_reasons[i]] embeddings=([output.embeddings[i]] if len(output.embeddings) > i else None),
if len(output.finished_reasons) > i prompt_tokens=(
else None [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
), ),
embeddings=( cached_tokens=(
[output.embeddings[i]] if len(output.embeddings) > i else None [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
), ),
prompt_tokens=( )
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None elif isinstance(output, BatchStrOut):
), new_output = BatchStrOut(
cached_tokens=( rids=[output.rids[i]],
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None finished_reasons=(
), [output.finished_reasons[i]]
) if len(output.finished_reasons) > i
elif isinstance(output, BatchStrOut): else None
new_output = BatchStrOut( ),
rids=[output.rids[i]], output_strs=(
finished_reasons=( [output.output_strs[i]] if len(output.output_strs) > i else None
[output.finished_reasons[i]] ),
if len(output.finished_reasons) > i output_ids=(
else None [output.output_ids[i]]
), if output.output_ids and len(output.output_ids) > i
output_strs=( else None
[output.output_strs[i]] if len(output.output_strs) > i else None ),
), prompt_tokens=(
output_ids=( [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
[output.output_ids[i]] ),
if output.output_ids and len(output.output_ids) > i completion_tokens=(
else None [output.completion_tokens[i]]
), if len(output.completion_tokens) > i
prompt_tokens=( else None
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None ),
), cached_tokens=(
completion_tokens=( [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
[output.completion_tokens[i]] ),
if len(output.completion_tokens) > i spec_verify_ct=(
else None [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
), ),
cached_tokens=( input_token_logprobs_val=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None [output.input_token_logprobs_val[i]]
), if output.input_token_logprobs_val
spec_verify_ct=( else None
[output.spec_verify_ct[i]] ),
if len(output.spec_verify_ct) > i input_token_logprobs_idx=(
else None [output.input_token_logprobs_idx[i]]
), if output.input_token_logprobs_idx
input_token_logprobs_val=( else None
[output.input_token_logprobs_val[i]] ),
if output.input_token_logprobs_val output_token_logprobs_val=(
else None [output.output_token_logprobs_val[i]]
), if output.output_token_logprobs_val
input_token_logprobs_idx=( else None
[output.input_token_logprobs_idx[i]] ),
if output.input_token_logprobs_idx output_token_logprobs_idx=(
else None [output.output_token_logprobs_idx[i]]
), if output.output_token_logprobs_idx
output_token_logprobs_val=( else None
[output.output_token_logprobs_val[i]] ),
if output.output_token_logprobs_val input_top_logprobs_val=(
else None [output.input_top_logprobs_val[i]]
), if output.input_top_logprobs_val
output_token_logprobs_idx=( else None
[output.output_token_logprobs_idx[i]] ),
if output.output_token_logprobs_idx input_top_logprobs_idx=(
else None [output.input_top_logprobs_idx[i]]
), if output.input_top_logprobs_idx
input_top_logprobs_val=( else None
[output.input_top_logprobs_val[i]] ),
if output.input_top_logprobs_val output_top_logprobs_val=(
else None [output.output_top_logprobs_val[i]]
), if output.output_top_logprobs_val
input_top_logprobs_idx=( else None
[output.input_top_logprobs_idx[i]] ),
if output.input_top_logprobs_idx output_top_logprobs_idx=(
else None [output.output_top_logprobs_idx[i]]
), if output.output_top_logprobs_idx
output_top_logprobs_val=( else None
[output.output_top_logprobs_val[i]] ),
if output.output_top_logprobs_val input_token_ids_logprobs_val=(
else None [output.input_token_ids_logprobs_val[i]]
), if output.input_token_ids_logprobs_val
output_top_logprobs_idx=( else None
[output.output_top_logprobs_idx[i]] ),
if output.output_top_logprobs_idx input_token_ids_logprobs_idx=(
else None [output.input_token_ids_logprobs_idx[i]]
), if output.input_token_ids_logprobs_idx
input_token_ids_logprobs_val=( else None
[output.input_token_ids_logprobs_val[i]] ),
if output.input_token_ids_logprobs_val output_token_ids_logprobs_val=(
else None [output.output_token_ids_logprobs_val[i]]
), if output.output_token_ids_logprobs_val
input_token_ids_logprobs_idx=( else None
[output.input_token_ids_logprobs_idx[i]] ),
if output.input_token_ids_logprobs_idx output_token_ids_logprobs_idx=(
else None [output.output_token_ids_logprobs_idx[i]]
), if output.output_token_ids_logprobs_idx
output_token_ids_logprobs_val=( else None
[output.output_token_ids_logprobs_val[i]] ),
if output.output_token_ids_logprobs_val output_hidden_states=(
else None [output.output_hidden_states[i]]
), if output.output_hidden_states
output_token_ids_logprobs_idx=( else None
[output.output_token_ids_logprobs_idx[i]] ),
if output.output_token_ids_logprobs_idx )
else None elif isinstance(output, BatchMultimodalOut):
), new_output = BatchMultimodalOut(
output_hidden_states=( rids=[output.rids[i]],
[output.output_hidden_states[i]] finished_reasons=(
if output.output_hidden_states [output.finished_reasons[i]]
else None if len(output.finished_reasons) > i
), else None
) ),
elif isinstance(output, BatchMultimodalOut): outputs=([output.outputs[i]] if len(output.outputs) > i else None),
new_output = BatchMultimodalOut( prompt_tokens=(
rids=[output.rids[i]], [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
finished_reasons=( ),
[output.finished_reasons[i]] completion_tokens=(
if len(output.finished_reasons) > i [output.completion_tokens[i]]
else None if len(output.completion_tokens) > i
), else None
outputs=([output.outputs[i]] if len(output.outputs) > i else None), ),
prompt_tokens=( cached_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
), ),
completion_tokens=( )
[output.completion_tokens[i]] else:
if len(output.completion_tokens) > i new_output = output
else None return new_output
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None class MultiHttpWorkerDetokenizerMixin:
), """Mixin class for MultiTokenizerManager and DetokenizerManager"""
)
else:
new_output = output
return new_output
def get_worker_ids_from_req_rids(self, rids): def get_worker_ids_from_req_rids(self, rids):
if isinstance(rids, list): if isinstance(rids, list):
@@ -350,9 +344,9 @@ class MultiTokenizerMixin:
worker_ids = [] worker_ids = []
return worker_ids return worker_ids
def multi_tokenizer_manager_event_loop(self): def multi_http_worker_event_loop(self):
"""The event loop that handles requests, for multi tokenizer manager mode only""" """The event loop that handles requests, for multi multi-http-worker mode"""
self.create_sockets_mapping() self.socket_mapping = SocketMapping()
while True: while True:
recv_obj = self.recv_from_scheduler.recv_pyobj() recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj) output = self._request_dispatcher(recv_obj)
@@ -369,31 +363,15 @@ class MultiTokenizerMixin:
# Send data using the corresponding socket # Send data using the corresponding socket
for i, worker_id in enumerate(worker_ids): for i, worker_id in enumerate(worker_ids):
if isinstance(recv_obj, MultiTokenizerRegisterReq): if isinstance(recv_obj, MultiTokenizerRegisterReq):
if self.register_tokenizer_ipc(recv_obj, worker_id): self.socket_mapping.register_ipc_mapping(
logger.info( recv_obj, worker_id, is_tokenizer=False
f"DetokenizerManager Created ZMQ socket for worker {worker_id}" )
)
continue
else: else:
if worker_id not in self.tokenizer_mapping: new_output = _handle_output_by_index(output, i)
logger.error( self.socket_mapping.send_output(worker_id, new_output)
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
)
continue
new_output = self._handle_output_by_index(output, i)
self.tokenizer_mapping[worker_id].send_pyobj(new_output)
def clear_tokenizer_mapping(self):
if hasattr(self, "tokenizer_mapping"):
for socket in self.tokenizer_mapping.values():
try:
socket.close()
except Exception as e:
logger.warning(f"Failed to close socket: {e}")
self.tokenizer_mapping.clear()
class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin): class MultiTokenizerRouter:
"""A router to receive requests from MultiTokenizerManager""" """A router to receive requests from MultiTokenizerManager"""
def __init__( def __init__(
@@ -422,7 +400,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
self._handle_task = asyncio.run_coroutine_threadsafe( self._handle_task = asyncio.run_coroutine_threadsafe(
print_exception_wrapper(self.handle_loop), self._loop print_exception_wrapper(self.handle_loop), self._loop
) )
self.init_disaggregation() self.disaggregation_bootstrap_server = start_disagg_service(self.server_args)
def _run_loop(self): def _run_loop(self):
self._loop.run_forever() self._loop.run_forever()
@@ -434,7 +412,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
async def handle_loop(self): async def handle_loop(self):
# special reqs will recv from scheduler, need to route to right worker # special reqs will recv from scheduler, need to route to right worker
self.create_sockets_mapping() self.socket_mapping = SocketMapping()
while True: while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj() recv_obj = await self.recv_from_detokenizer.recv_pyobj()
await self._distribute_result_to_workers(recv_obj) await self._distribute_result_to_workers(recv_obj)
@@ -454,22 +432,15 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
# Distribute result to each worker # Distribute result to each worker
for i, worker_id in enumerate(worker_ids): for i, worker_id in enumerate(worker_ids):
if isinstance(recv_obj, MultiTokenizerRegisterReq): if isinstance(recv_obj, MultiTokenizerRegisterReq):
if self.register_tokenizer_ipc(recv_obj, worker_id): self.socket_mapping.register_ipc_mapping(
logger.info( recv_obj, worker_id, is_tokenizer=True
f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}" )
)
continue
else: else:
if worker_id not in self.tokenizer_mapping: new_recv_obj = _handle_output_by_index(recv_obj, i)
logger.error( self.socket_mapping.send_output(worker_id, new_recv_obj)
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
)
continue
new_recv_obj = self._handle_output_by_index(recv_obj, i)
self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin): class MultiTokenizerManager(TokenizerManager):
"""Multi Process Tokenizer Manager that tokenizes the text.""" """Multi Process Tokenizer Manager that tokenizes the text."""
def __init__( def __init__(
@@ -535,42 +506,14 @@ async def print_exception_wrapper(func):
sys.exit(1) sys.exit(1)
def serialize_port_args(port_args: PortArgs) -> dict: def get_main_process_id() -> int:
"""Serialize PortArgs into a shareable dictionary""" """Get the main process ID"""
return { return multiprocessing.current_process()._parent_pid
"tokenizer_ipc_name": port_args.tokenizer_ipc_name,
"scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
"detokenizer_ipc_name": port_args.detokenizer_ipc_name,
"nccl_port": port_args.nccl_port,
"rpc_ipc_name": port_args.rpc_ipc_name,
"metrics_ipc_name": port_args.metrics_ipc_name,
"tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
}
def deserialize_data(port_args: dict, server_args: dict): def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory:
"""Deserialize data from shared dictionaries"""
return PortArgs(**port_args), ServerArgs(**server_args)
def serialize_server_args(server_args: ServerArgs) -> dict:
"""Serialize ServerArgs into a shareable dictionary"""
return dataclasses.asdict(server_args)
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
"""Serialize scheduler_info into a shareable dictionary"""
return scheduler_info
def deserialize_scheduler_info(data: dict) -> Dict:
    """Recover scheduler_info from its shared representation.

    The shared form is already a plain dict, so the input is returned
    unchanged (no copy is made).
    """
    restored = data  # shared form equals in-memory form; pass through
    return restored
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
"""Write data to shared memory""" """Write data to shared memory"""
serialized = json.dumps(data).encode("utf-8") serialized = pickle.dumps(obj)
size = len(serialized) size = len(serialized)
try: try:
# Try to open existing shared memory # Try to open existing shared memory
@@ -588,22 +531,17 @@ def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
return shm return shm
def read_from_shared_memory(name: str) -> dict: def read_from_shared_memory(name: str) -> Any:
"""Read data from shared memory""" """Read data from shared memory"""
try: try:
shm = shared_memory.SharedMemory(name=name) shm = shared_memory.SharedMemory(name=name)
data = json.loads(bytes(shm.buf).decode("utf-8")) data = pickle.loads(bytes(shm.buf))
shm.close() shm.close()
return data return data
except FileNotFoundError: except FileNotFoundError:
raise FileNotFoundError(f"Shared memory {name} not found") raise FileNotFoundError(f"Shared memory {name} not found")
def get_main_process_id() -> int:
    """Return the PID of this process's parent.

    Reads the private ``_parent_pid`` attribute of the current
    multiprocessing process object; in a worker spawned by the main server
    process this is the main process's PID. NOTE(review): in the top-level
    (main) process this attribute is ``None`` — confirm callers only invoke
    this from workers.
    """
    current = multiprocessing.current_process()
    return current._parent_pid
def write_data_for_multi_tokenizer( def write_data_for_multi_tokenizer(
port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
): ):
@@ -612,22 +550,8 @@ def write_data_for_multi_tokenizer(
main_pid = get_main_process_id() main_pid = get_main_process_id()
current_pid = os.getpid() current_pid = os.getpid()
logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}") logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
args = (port_args, server_args, scheduler_info)
args_shm = write_to_shared_memory(args, f"multi_tokenizer_args_{current_pid}")
args_shm.close()
# Write port_args to shared memory return args_shm
port_args_shm = write_to_shared_memory(
serialize_port_args(port_args), f"port_args_{current_pid}"
)
# Write server_args to shared memory
server_args_shm = write_to_shared_memory(
serialize_server_args(server_args), f"server_args_{current_pid}"
)
# Write scheduler_info to shared memory
scheduler_info_shm = write_to_shared_memory(
serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
)
port_args_shm.close()
server_args_shm.close()
scheduler_info_shm.close()
return port_args_shm, server_args_shm, scheduler_info_shm

View File

@@ -54,19 +54,14 @@ from fastapi import BackgroundTasks
from sglang.srt.aio_rwlock import RWLock from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.base import BaseKVBootstrapServer from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
TransferBackend,
get_kv_class,
)
from sglang.srt.hf_transformers_utils import ( from sglang.srt.hf_transformers_utils import (
get_processor, get_processor,
get_tokenizer, get_tokenizer,
get_tokenizer_from_processor, get_tokenizer_from_processor,
) )
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchEmbeddingOut, BatchEmbeddingOut,
@@ -321,8 +316,10 @@ class TokenizerManager:
# LoRA updates and inference to overlap. # LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock() self.lora_update_lock = asyncio.Lock()
# For PD disaggregtion self.disaggregation_mode = DisaggregationMode(
self.init_disaggregation() self.server_args.disaggregation_mode
)
self.bootstrap_server = start_disagg_service(self.server_args)
# For load balancing # For load balancing
self.current_load = 0 self.current_load = 0
@@ -471,38 +468,6 @@ class TokenizerManager:
] ]
) )
def init_disaggregation(self):
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.disaggregation_transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv boostrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
host=self.server_args.host,
port=self.server_args.disaggregation_bootstrap_port,
)
is_create_store = (
self.server_args.node_rank == 0
and self.server_args.disaggregation_transfer_backend == "ascend"
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
async def generate_request( async def generate_request(
self, self,
obj: Union[GenerateReqInput, EmbeddingReqInput], obj: Union[GenerateReqInput, EmbeddingReqInput],