From a9471542867ce938339db46098bdea7447f70562 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 8 Aug 2025 02:28:27 -0700 Subject: [PATCH] Revert "Support Multi Process Tokenizer Manager" (#8960) --- python/sglang/srt/entrypoints/http_server.py | 308 ++--------- .../srt/managers/detokenizer_manager.py | 198 +------ python/sglang/srt/managers/io_struct.py | 56 +- python/sglang/srt/managers/scheduler.py | 9 +- .../sglang/srt/managers/tokenizer_manager.py | 505 +----------------- python/sglang/srt/server_args.py | 15 - python/sglang/srt/utils.py | 14 - test/srt/run_suite.py | 1 - test/srt/test_multi_tokenizer.py | 100 ---- 9 files changed, 73 insertions(+), 1133 deletions(-) delete mode 100644 test/srt/test_multi_tokenizer.py diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index f60beedfe..c4d36088f 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -18,18 +18,14 @@ This file implements HTTP APIs for the inference engine via fastapi. """ import asyncio -import ctypes import dataclasses import json import logging import multiprocessing as multiprocessing import os -import sys -import tempfile import threading import time from http import HTTPStatus -from multiprocessing import Lock, Manager, Value, shared_memory from typing import AsyncIterator, Callable, Dict, Optional # Fix a bug of Python threading @@ -98,7 +94,7 @@ from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.reasoning_parser import ReasoningParser -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( add_api_key_middleware, add_prometheus_middleware, @@ -133,165 +129,8 @@ def set_global_state(global_state: _GlobalState): _global_state = global_state -def serialize_port_args(port_args: PortArgs) -> dict: - """Serialize PortArgs into a shareable dictionary""" - return { - "tokenizer_ipc_name": port_args.tokenizer_ipc_name, - "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name, - "detokenizer_ipc_name": port_args.detokenizer_ipc_name, - "nccl_port": port_args.nccl_port, - "rpc_ipc_name": port_args.rpc_ipc_name, - "metrics_ipc_name": port_args.metrics_ipc_name, - "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name, - } - - -def deserialize_port_args(data: dict) -> PortArgs: - """Deserialize PortArgs from a shared dictionary""" - return PortArgs(**data) - - -def serialize_server_args(server_args: ServerArgs) -> dict: - """Serialize ServerArgs into a shareable dictionary""" - return dataclasses.asdict(server_args) - - -def deserialize_server_args(data: dict) -> ServerArgs: - """Deserialize ServerArgs from a shared dictionary""" - return ServerArgs(**data) - - -def serialize_scheduler_info(scheduler_info: Dict) -> dict: - """Serialize scheduler_info into a shareable dictionary""" - return scheduler_info - - -def deserialize_scheduler_info(data: dict) -> Dict: - """Deserialize scheduler_info from a shared dictionary""" - return data - - -def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory: - """Write data to shared memory""" - serialized = json.dumps(data).encode("utf-8") - size = len(serialized) - try: - # Try to open existing shared memory - shm = shared_memory.SharedMemory(name=name) - # If size is insufficient, close and recreate - if shm.size < size: - shm.close() - shm.unlink() - shm = shared_memory.SharedMemory(create=True, size=size, name=name) - except FileNotFoundError: - # If not present, create new shared memory - shm = shared_memory.SharedMemory(create=True, size=size, name=name) - - shm.buf[:size] = serialized - return shm - - -def read_from_shared_memory(name: str) -> dict: - """Read data from shared memory""" - try: - shm = shared_memory.SharedMemory(name=name) - data = json.loads(bytes(shm.buf).decode("utf-8")) - shm.close() - return data - except FileNotFoundError: - raise FileNotFoundError(f"Shared memory {name} not found") - - -def get_main_process_id() -> int: - """Get the main process ID""" - return multiprocessing.current_process()._parent_pid - - -def write_data_for_multi_tokenizer( - port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict -): - """Write args information to share memory for multi-tokenizer""" - # get main process ID - main_pid = get_main_process_id() - current_pid = os.getpid() - logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}") - - # Write port_args to shared memory - port_args_shm = write_to_shared_memory( - serialize_port_args(port_args), f"port_args_{current_pid}" - ) - # Write server_args to shared memory - server_args_shm = write_to_shared_memory( - serialize_server_args(server_args), f"server_args_{current_pid}" - ) - # Write scheduler_info to shared memory - scheduler_info_shm = write_to_shared_memory( - serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}" - ) - - port_args_shm.close() - server_args_shm.close() - scheduler_info_shm.close() - - return port_args_shm, server_args_shm, scheduler_info_shm - - -def init_multi_tokenizer() -> ServerArgs: - """Read args information from shm and init tokenizer manager for current process""" - pid = os.getpid() - main_pid = get_main_process_id() - logger.info(f"current worker_id: {pid}, main processID: {main_pid}") - - # Read port_args, server_args, and scheduler_info from shared memory - port_args_data = read_from_shared_memory(f"port_args_{main_pid}") - server_args_data = read_from_shared_memory(f"server_args_{main_pid}") - scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}") - port_args = deserialize_port_args(port_args_data) - server_args = deserialize_server_args(server_args_data) - scheduler_info = deserialize_scheduler_info(scheduler_info_data) - - port_args.tokenizer_ipc_name = ( - f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - ) - - # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args, False) - template_manager = TemplateManager() - template_manager.initialize_templates( - tokenizer_manager=tokenizer_manager, - model_path=server_args.model_path, - chat_template=server_args.chat_template, - completion_template=server_args.completion_template, - ) - # register multi tokenizer - tokenizer_manager.register_to_main_tokenizer_manager() - - tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] - set_global_state( - _GlobalState( - tokenizer_manager=tokenizer_manager, - template_manager=template_manager, - scheduler_info=scheduler_info, - ) - ) - return server_args - - @asynccontextmanager async def lifespan(fast_api_app: FastAPI): - server_args = getattr(fast_api_app, "server_args", None) - if server_args is None: - # for multi-tokenizer - fast_api_app.server_args = init_multi_tokenizer() - fast_api_app.warmup_thread = threading.Thread( - target=_wait_and_warmup, - args=( - fast_api_app.server_args, - None, # pipe_finish_writer not needed in worker - None, # launch_callback not needed in worker - ), - ) - # Initialize OpenAI serving handlers fast_api_app.state.openai_serving_completion = OpenAIServingCompletion( _global_state.tokenizer_manager, _global_state.template_manager @@ -352,15 +191,7 @@ async def lifespan(fast_api_app: FastAPI): warmup_thread = getattr(fast_api_app, "warmup_thread", None) if warmup_thread is not None: warmup_thread.start() - - try: - yield - finally: - if server_args.tokenizer_worker_num > 1: - pid = os.getpid() - logger.info(f"uvicorn worker {pid} ending...") - warmup_thread.join() - logger.info(f"uvicorn {pid} ended") + yield # Fast API @@ -377,30 +208,6 @@ app.add_middleware( ) -# Function to setup all middlewares for multi-process compatibility -def setup_middlewares(): - """Setup all middlewares for both single and multi-process modes""" - worker_pid = os.getpid() - - # Setup API key middleware - api_key = os.environ.get("SGLANG_API_KEY", "") - if api_key: - add_api_key_middleware(app, api_key) - logger.info(f"Worker {worker_pid} added API key middleware") - - # Setup prometheus middleware - # Check if metrics are enabled via environment variable - enable_metrics = get_bool_env_var("SGLANG_ENABLE_METRICS", "false") - if enable_metrics: - add_prometheus_middleware(app) - enable_func_timer() - logger.info(f"Worker {worker_pid} added prometheus middleware") - - -# Call setup function at module level for multi-process compatibility -setup_middlewares() - - @app.exception_handler(HTTPException) async def validation_exception_handler(request: Request, exc: HTTPException): """Enrich HTTP exception with status code and other details""" @@ -1186,19 +993,9 @@ def launch_server( 1. The HTTP server, Engine, and TokenizerManager both run in the main process. 2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library. """ - if server_args.tokenizer_worker_num > 1: - port_args = PortArgs.init_new(server_args) - port_args.tokenizer_worker_ipc_name = ( - f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - ) - tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( - server_args=server_args, port_args=port_args - ) - else: - tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( - server_args=server_args, - ) - + tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( + server_args=server_args + ) set_global_state( _GlobalState( tokenizer_manager=tokenizer_manager, @@ -1207,83 +1004,42 @@ def launch_server( ) ) - if server_args.tokenizer_worker_num > 1: - # Set environment variables for middlewares in main process - if server_args.api_key: - os.environ["SGLANG_API_KEY"] = server_args.api_key - logger.info("Main process set SGLANG_API_KEY") + # Add api key authorization + if server_args.api_key: + add_api_key_middleware(app, server_args.api_key) - if server_args.enable_metrics: - os.environ["SGLANG_ENABLE_METRICS"] = "true" - logger.info("Main process set SGLANG_ENABLE_METRICS=true") + # Add prometheus middleware + if server_args.enable_metrics: + add_prometheus_middleware(app) + enable_func_timer() - port_args_shm, server_args_shm, scheduler_info_shm = ( - write_data_for_multi_tokenizer( - port_args, - server_args, - scheduler_info, - ) - ) - else: - # Add api key authorization - if server_args.api_key: - add_api_key_middleware(app, server_args.api_key) - - # Add prometheus middleware - if server_args.enable_metrics: - add_prometheus_middleware(app) - enable_func_timer() - - # Send a warmup request - we will create the thread launch it - # in the lifespan after all other warmups have fired. - warmup_thread = threading.Thread( - target=_wait_and_warmup, - args=( - server_args, - pipe_finish_writer, - launch_callback, - ), - ) - app.warmup_thread = warmup_thread + # Send a warmup request - we will create the thread launch it + # in the lifespan after all other warmups have fired. + warmup_thread = threading.Thread( + target=_wait_and_warmup, + args=( + server_args, + pipe_finish_writer, + launch_callback, + ), + ) + app.warmup_thread = warmup_thread try: # Update logging configs set_uvicorn_logging_configs() app.server_args = server_args # Listen for HTTP requests - if server_args.tokenizer_worker_num > 1: - from uvicorn.config import LOGGING_CONFIG - - LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = { - "handlers": ["default"], - "level": "INFO", - "propagate": False, - } - uvicorn.run( - "sglang.srt.entrypoints.http_server:app", - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level_http or server_args.log_level, - timeout_keep_alive=5, - loop="uvloop", - workers=server_args.tokenizer_worker_num, - ) - else: - uvicorn.run( - app, - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level_http or server_args.log_level, - timeout_keep_alive=5, - loop="uvloop", - ) + uvicorn.run( + app, + host=server_args.host, + port=server_args.port, + log_level=server_args.log_level_http or server_args.log_level, + timeout_keep_alive=5, + loop="uvloop", + ) finally: - if server_args.tokenizer_worker_num > 1: - port_args_shm.unlink() - server_args_shm.unlink() - scheduler_info_shm.unlink() - else: - warmup_thread.join() + warmup_thread.join() def _execute_server_warmup( diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 2a626ca85..29757b4b2 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -31,12 +31,10 @@ from sglang.srt.managers.io_struct import ( BatchMultimodalOut, BatchStrOut, BatchTokenIDOut, - MultiTokenizerRegisterReq, ) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( configure_logger, - get_workerids_from_rids, get_zmq_socket, kill_itself_when_parent_died, ) @@ -83,6 +81,7 @@ class DetokenizerManager: self.send_to_tokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) + if server_args.skip_tokenizer_init: self.tokenizer = None else: @@ -95,208 +94,21 @@ class DetokenizerManager: self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) self.is_dummy = server_args.load_format == "dummy" - self.tokenizer_worker_num = server_args.tokenizer_worker_num + self._request_dispatcher = TypeBasedDispatcher( [ (BatchEmbeddingOut, self.handle_batch_embedding_out), (BatchTokenIDOut, self.handle_batch_token_id_out), (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req), - (MultiTokenizerRegisterReq, lambda x: None), ] ) def event_loop(self): """The event loop that handles requests""" - while True: - try: - recv_obj = self.recv_from_scheduler.recv_pyobj() - output = self._request_dispatcher(recv_obj) - if self.tokenizer_worker_num <= 1: - self.send_to_tokenizer.send_pyobj(output) - else: - # Extract worker_id from rid - if isinstance(recv_obj.rids, list): - worker_ids = get_workerids_from_rids(recv_obj.rids) - else: - raise RuntimeError( - f"tokenizer_worker_num > 1, recv_obj.rids must be list" - ) - - if not hasattr(self, "tokenizer_mapping"): - self.tokenizer_mapping = {} - - # Create ZMQ context if needed - if not hasattr(self, "_zmq_context"): - self._zmq_context = zmq.Context() - - # Send data using the corresponding socket - for i, worker_id in enumerate(worker_ids): - if worker_id not in self.tokenizer_mapping: - # register the worker if not already done - if isinstance(recv_obj, MultiTokenizerRegisterReq): - self.init_tokenizer_mapping(recv_obj, worker_id) - else: - logger.error( - f"Worker {worker_id} not registered and not found in tokenizer mapping . " - "Please ensure the worker is registered correctly." - ) - continue - else: - if isinstance(recv_obj, MultiTokenizerRegisterReq): - continue - - # Create a new output object based on the type - if isinstance(output, BatchEmbeddingOut): - new_output = BatchEmbeddingOut( - rids=[output.rids[i]], - finished_reasons=[output.finished_reasons[i]], - embeddings=[output.embeddings[i]], - prompt_tokens=[output.prompt_tokens[i]], - cached_tokens=[output.cached_tokens[i]], - ) - elif isinstance(output, BatchStrOut): - new_output = BatchStrOut( - rids=[output.rids[i]], - finished_reasons=( - [output.finished_reasons[i]] - if len(output.finished_reasons) > i - else None - ), - output_strs=( - [output.output_strs[i]] - if len(output.output_strs) > i - else None - ), - output_ids=( - [output.output_ids[i]] - if output.output_ids and len(output.output_ids) > i - else None - ), - prompt_tokens=( - [output.prompt_tokens[i]] - if len(output.prompt_tokens) > i - else None - ), - completion_tokens=( - [output.completion_tokens[i]] - if len(output.completion_tokens) > i - else None - ), - cached_tokens=( - [output.cached_tokens[i]] - if len(output.cached_tokens) > i - else None - ), - spec_verify_ct=( - [output.spec_verify_ct[i]] - if len(output.spec_verify_ct) > i - else None - ), - input_token_logprobs_val=( - [output.input_token_logprobs_val[i]] - if output.input_token_logprobs_val - else None - ), - input_token_logprobs_idx=( - [output.input_token_logprobs_idx[i]] - if output.input_token_logprobs_idx - else None - ), - output_token_logprobs_val=( - [output.output_token_logprobs_val[i]] - if output.output_token_logprobs_val - else None - ), - output_token_logprobs_idx=( - [output.output_token_logprobs_idx[i]] - if output.output_token_logprobs_idx - else None - ), - input_top_logprobs_val=( - [output.input_top_logprobs_val[i]] - if output.input_top_logprobs_val - else None - ), - input_top_logprobs_idx=( - [output.input_top_logprobs_idx[i]] - if output.input_top_logprobs_idx - else None - ), - output_top_logprobs_val=( - [output.output_top_logprobs_val[i]] - if output.output_top_logprobs_val - else None - ), - output_top_logprobs_idx=( - [output.output_top_logprobs_idx[i]] - if output.output_top_logprobs_idx - else None - ), - input_token_ids_logprobs_val=( - [output.input_token_ids_logprobs_val[i]] - if output.input_token_ids_logprobs_val - else None - ), - input_token_ids_logprobs_idx=( - [output.input_token_ids_logprobs_idx[i]] - if output.input_token_ids_logprobs_idx - else None - ), - output_token_ids_logprobs_val=( - [output.output_token_ids_logprobs_val[i]] - if output.output_token_ids_logprobs_val - else None - ), - output_token_ids_logprobs_idx=( - [output.output_token_ids_logprobs_idx[i]] - if output.output_token_ids_logprobs_idx - else None - ), - output_hidden_states=( - [output.output_hidden_states[i]] - if output.output_hidden_states - else None - ), - ) - elif isinstance(output, BatchMultimodalOut): - new_output = BatchMultimodalOut( - rids=[output.rids[i]], - finished_reasons=[output.finished_reasons[i]], - prompt_tokens=[output.prompt_tokens[i]], - completion_tokens=[output.completion_tokens[i]], - cached_tokens=[output.cached_tokens[i]], - ) - else: - new_output = output - - try: - self.tokenizer_mapping[worker_id].send_pyobj(new_output) - except zmq.error.ZMQError as e: - logger.info( - f"ZMQ error when sending to worker {worker_id}: {e}" - ) - except Exception as e: - logger.error(f"Error in detokenizer event loop: {e}") - raise e - - def init_tokenizer_mapping( - self, recv_obj: MultiTokenizerRegisterReq, worker_id: str - ): - """init tokenizer mapping from register request""" - ipc_name = recv_obj.ipc_name - worker_id_int = int(worker_id) - - if worker_id_int not in self.tokenizer_mapping: - socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False) - self.tokenizer_mapping[worker_id_int] = socket - logger.info( - f"Detokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}" - ) - else: - logger.info( - f"ZMQ socket for worker {worker_id} already exists, skipping creation" - ) + recv_obj = self.recv_from_scheduler.recv_pyobj() + output = self._request_dispatcher(recv_obj) + self.send_to_tokenizer.send_pyobj(output) def trim_matched_stop( self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index a42254331..546128212 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -782,13 +782,12 @@ class BatchEmbeddingOut: @dataclass class FlushCacheReqInput: - rids: Optional[Union[List[str], str]] = None + pass @dataclass class FlushCacheReqOutput: success: bool - rids: Optional[Union[List[str], str]] = None @dataclass @@ -799,7 +798,6 @@ class UpdateWeightFromDiskReqInput: load_format: Optional[str] = None # Whether to abort all requests before updating weights abort_all_requests: bool = False - rids: Optional[Union[List[str], str]] = None @dataclass @@ -808,7 +806,6 @@ class UpdateWeightFromDiskReqOutput: message: str # Number of paused requests during weight sync. num_paused_requests: Optional[int] = 0 - rids: Optional[Union[List[str], str]] = None @dataclass @@ -822,14 +819,12 @@ class UpdateWeightsFromDistributedReqInput: flush_cache: bool = True # Whether to abort all requests before updating weights abort_all_requests: bool = False - rids: Optional[Union[List[str], str]] = None @dataclass class UpdateWeightsFromDistributedReqOutput: success: bool message: str - rids: Optional[Union[List[str], str]] = None @dataclass @@ -847,14 +842,12 @@ class UpdateWeightsFromTensorReqInput: flush_cache: bool = True # Whether to abort all requests before updating weights abort_all_requests: bool = False - rids: Optional[Union[List[str], str]] = None @dataclass class UpdateWeightsFromTensorReqOutput: success: bool message: str - rids: Optional[Union[List[str], str]] = None @dataclass @@ -871,27 +864,23 @@ class InitWeightsUpdateGroupReqInput: group_name: str = "weight_update_group" # The backend backend: str = "nccl" - rids: Optional[Union[List[str], str]] = None @dataclass class InitWeightsUpdateGroupReqOutput: success: bool message: str - rids: Optional[Union[List[str], str]] = None @dataclass class GetWeightsByNameReqInput: name: str truncate_size: int = 100 - rids: Optional[Union[List[str], str]] = None @dataclass class GetWeightsByNameReqOutput: parameter: list - rids: Optional[Union[List[str], str]] = None @dataclass @@ -899,12 +888,11 @@ class ReleaseMemoryOccupationReqInput: # Optional tags to identify the memory region, which is primarily used for RL # Currently we only support `weights` and `kv_cache` tags: Optional[List[str]] = None - rids: Optional[Union[List[str], str]] = None @dataclass class ReleaseMemoryOccupationReqOutput: - rids: Optional[Union[List[str], str]] = None + pass @dataclass @@ -912,23 +900,21 @@ class ResumeMemoryOccupationReqInput: # Optional tags to identify the memory region, which is primarily used for RL # Currently we only support `weights` and `kv_cache` tags: Optional[List[str]] = None - rids: Optional[Union[List[str], str]] = None @dataclass class ResumeMemoryOccupationReqOutput: - rids: Optional[Union[List[str], str]] = None + pass @dataclass class SlowDownReqInput: forward_sleep_time: Optional[float] - rids: Optional[Union[List[str], str]] = None @dataclass class SlowDownReqOutput: - rids: Optional[Union[List[str], str]] = None + pass @dataclass @@ -937,37 +923,29 @@ class AbortReq: rid: str = "" # Whether to abort all requests abort_all: bool = False - - rids: Optional[Union[List[str], str]] = None - + # The finished reason data finished_reason: Optional[Dict[str, Any]] = None - def __post_init__(self): - self.rids = self.rid - @dataclass class GetInternalStateReq: - rids: Optional[Union[List[str], str]] = None + pass @dataclass class GetInternalStateReqOutput: internal_state: Dict[Any, Any] - rids: Optional[Union[List[str], str]] = None @dataclass class SetInternalStateReq: server_args: Dict[str, Any] - rids: Optional[Union[List[str], str]] = None @dataclass class SetInternalStateReqOutput: updated: bool server_args: Dict[str, Any] - rids: Optional[Union[List[str], str]] = None @dataclass @@ -983,7 +961,6 @@ class ProfileReqInput: profile_by_stage: bool = False with_stack: Optional[bool] = None record_shapes: Optional[bool] = None - rids: Optional[Union[List[str], str]] = None class ProfileReqType(Enum): @@ -1002,14 +979,12 @@ class ProfileReq: with_stack: Optional[bool] = None record_shapes: Optional[bool] = None profile_id: Optional[str] = None - rids: Optional[Union[List[str], str]] = None @dataclass class ProfileReqOutput: success: bool message: str - rids: Optional[Union[List[str], str]] = None @dataclass @@ -1018,32 +993,27 @@ class ConfigureLoggingReq: log_requests_level: Optional[int] = None dump_requests_folder: Optional[str] = None dump_requests_threshold: Optional[int] = None - rids: Optional[Union[List[str], str]] = None @dataclass class OpenSessionReqInput: capacity_of_str_len: int session_id: Optional[str] = None - rids: Optional[Union[List[str], str]] = None @dataclass class CloseSessionReqInput: session_id: str - rids: Optional[Union[List[str], str]] = None @dataclass class OpenSessionReqOutput: session_id: Optional[str] success: bool - rids: Optional[Union[List[str], str]] = None @dataclass class HealthCheckOutput: - rids: Optional[Union[List[str], str]] = None pass @@ -1055,7 +1025,7 @@ class ExpertDistributionReq(Enum): @dataclass class ExpertDistributionReqOutput: - rids: Optional[Union[List[str], str]] = None + pass @dataclass @@ -1080,21 +1050,18 @@ class ParseFunctionCallReq: tool_call_parser: Optional[str] = ( None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all. ) - rids: Optional[Union[List[str], str]] = None @dataclass class SeparateReasoningReqInput: text: str # The text to parse. reasoning_parser: str # Specify the parser type, e.g., "deepseek-r1". - rids: Optional[Union[List[str], str]] = None @dataclass class VertexGenerateReqInput: instances: List[dict] parameters: Optional[dict] = None - rids: Optional[Union[List[str], str]] = None @dataclass @@ -1119,7 +1086,6 @@ class LoadLoRAAdapterReqInput: pinned: bool = False # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. lora_id: Optional[str] = None - rids: Optional[Union[List[str], str]] = None def to_ref(self) -> LoRARef: return LoRARef( @@ -1136,7 +1102,6 @@ class UnloadLoRAAdapterReqInput: lora_name: str # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. lora_id: Optional[str] = None - rids: Optional[Union[List[str], str]] = None def to_ref(self) -> LoRARef: return LoRARef( @@ -1150,18 +1115,11 @@ class LoRAUpdateResult: success: bool error_message: Optional[str] = None loaded_adapters: Optional[Dict[str, LoRARef]] = None - rids: Optional[Union[List[str], str]] = None LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult -@dataclass -class MultiTokenizerRegisterReq: - rids: Optional[Union[List[str], str]] = None - ipc_name: Optional[str] = None - - class BlockReqType(Enum): BLOCK = 1 UNBLOCK = 2 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5aef261c1..a97cca261 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -79,7 +79,6 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, LoadLoRAAdapterReqOutput, - MultiTokenizerRegisterReq, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, @@ -248,6 +247,7 @@ class Scheduler( # Init inter-process communication context = zmq.Context(2) self.idle_sleeper = None + if self.pp_rank == 0 and self.attn_tp_rank == 0: self.recv_from_tokenizer = get_zmq_socket( context, zmq.PULL, port_args.scheduler_input_ipc_name, False @@ -522,7 +522,6 @@ class Scheduler( (ExpertDistributionReq, self.expert_distribution_handle), (LoadLoRAAdapterReqInput, self.load_lora_adapter), (UnloadLoRAAdapterReqInput, self.unload_lora_adapter), - (MultiTokenizerRegisterReq, self.register_multi_tokenizer), ] ) @@ -1065,8 +1064,6 @@ class Scheduler( if self.recv_from_rpc is not None: self.recv_from_rpc.send_pyobj(output) else: - if recv_req.rids is not None: - output.rids = recv_req.rids self.send_to_tokenizer.send_pyobj(output) def handle_generate_request( @@ -2407,10 +2404,6 @@ class Scheduler( result = self.tp_worker.unload_lora_adapter(recv_req) return result - def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq): - self.send_to_detokenizer.send_pyobj(recv_req) - return recv_req - def slow_down(self, recv_req: SlowDownReqInput): t = recv_req.forward_sleep_time if t is not None and t <= 0: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 700c290f5..50ac39f88 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -89,7 +89,6 @@ from sglang.srt.managers.io_struct import ( LoadLoRAAdapterReqInput, LoadLoRAAdapterReqOutput, LoRAUpdateResult, - MultiTokenizerRegisterReq, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, @@ -125,8 +124,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( dataclass_to_string_truncated, get_bool_env_var, - get_origin_rid, - get_workerids_from_rids, get_zmq_socket, kill_process_tree, ) @@ -174,9 +171,6 @@ class ReqState: output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) -_global_tokenizer_worker_num = 1 - - class TokenizerManager: """TokenizerManager is a process that tokenizes the text.""" @@ -184,7 +178,6 @@ class TokenizerManager: self, server_args: ServerArgs, port_args: PortArgs, - is_main: Optional[bool] = True, ): # Parse args self.server_args = server_args @@ -198,9 +191,6 @@ class TokenizerManager: ) self.crash_dump_folder = server_args.crash_dump_folder - self.is_main = is_main - self.worker_id = os.getpid() - # Read model args self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name @@ -265,41 +255,13 @@ class TokenizerManager: ) # Init inter-process communication - context = zmq.asyncio.Context(3) + context = zmq.asyncio.Context(2) self.recv_from_detokenizer = get_zmq_socket( context, zmq.PULL, port_args.tokenizer_ipc_name, True ) - global _global_tokenizer_worker_num - _global_tokenizer_worker_num = server_args.tokenizer_worker_num - if server_args.tokenizer_worker_num > 1: - self.tokenizer_ipc_name = port_args.tokenizer_ipc_name - if self.is_main: - self.send_to_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.scheduler_input_ipc_name, True - ) - self.receive_from_worker = get_zmq_socket( - context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True - ) - self._loop = asyncio.new_event_loop() - self._thread = threading.Thread(target=self._run_loop, daemon=True) - self._thread.start() - self._task = asyncio.run_coroutine_threadsafe( - self.router_worker_obj(), self._loop - ) - # Start handle_loop simultaneously - self._handle_task = asyncio.run_coroutine_threadsafe( - print_exception_wrapper(self.handle_loop), self._loop - ) - - else: - # actual send to main receiver_from_worker - self.send_to_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False - ) - else: - self.send_to_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.scheduler_input_ipc_name, True - ) + self.send_to_scheduler = get_zmq_socket( + context, zmq.PUSH, port_args.scheduler_input_ipc_name, True + ) # Request states self.no_create_loop = False @@ -353,27 +315,26 @@ class TokenizerManager: # Start kv boostrap server on prefill if self.disaggregation_mode == DisaggregationMode.PREFILL: # only start bootstrap server on prefill tm - if self.is_main: - kv_bootstrap_server_class = get_kv_class( - self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER - ) - self.bootstrap_server = kv_bootstrap_server_class( - self.server_args.disaggregation_bootstrap_port - ) - is_create_store = ( - self.server_args.node_rank == 0 - and self.server_args.disaggregation_transfer_backend == "ascend" - ) - if is_create_store: - try: - from mf_adapter import create_config_store + kv_bootstrap_server_class = get_kv_class( + self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER + ) + self.bootstrap_server = kv_bootstrap_server_class( + self.server_args.disaggregation_bootstrap_port + ) + is_create_store = ( + self.server_args.node_rank == 0 + and self.server_args.disaggregation_transfer_backend == "ascend" + ) + if is_create_store: + try: + from mf_adapter import create_config_store - ascend_url = os.getenv("ASCEND_MF_STORE_URL") - create_config_store(ascend_url) - except Exception as e: - error_message = f"Failed create mf store, invalid ascend_url." - error_message += f" With exception {e}" - raise error_message + ascend_url = os.getenv("ASCEND_MF_STORE_URL") + create_config_store(ascend_url) + except Exception as e: + error_message = f"Failed create mf store, invalid ascend_url." + error_message += f" With exception {e}" + raise error_message # For load balancing self.current_load = 0 @@ -506,14 +467,6 @@ class TokenizerManager: ] ) - def _run_loop(self): - self._loop.run_forever() - - async def router_worker_obj(self): - while True: - recv_obj = await self.receive_from_worker.recv_pyobj() - await self.send_to_scheduler.send_pyobj(recv_obj) - async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -526,15 +479,6 @@ class TokenizerManager: async with self._is_updating_cond: await self._is_updating_cond.wait_for(lambda: not self._is_updating) - if self.server_args.tokenizer_worker_num > 1: - # Modify rid, add worker_id - if isinstance(obj.rid, list): - # If it's an array, add worker_id prefix to each element - obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid] - else: - # If it's a single value, add worker_id prefix - obj.rid = f"{self.worker_id}_{obj.rid}" - if self.log_requests: max_length, skip_names, _ = self.log_request_metadata logger.info( @@ -1561,377 +1505,11 @@ class TokenizerManager: async def handle_loop(self): """The event loop that handles requests""" + while True: recv_obj = await self.recv_from_detokenizer.recv_pyobj() - # In multi-worker mode, distribute results to corresponding workers - if self.server_args.tokenizer_worker_num > 1 and self.is_main: - await self._distribute_result_to_workers(recv_obj) - else: - # In single worker mode, process directly - self._result_dispatcher(recv_obj) - - self.last_receive_tstamp = time.time() - - def init_tokenizer_mapping(self, recv_obj: MultiTokenizerRegisterReq): - """init tokenizer mapping from register request""" - if isinstance(recv_obj.rids, list): - worker_ids = get_workerids_from_rids(recv_obj.rids) - else: - raise RuntimeError(f"tokenizer_worker_num > 1, recv_obj.rids must be list") - - for worker_id in worker_ids: - ipc_name = recv_obj.ipc_name - worker_id_int = int(worker_id) - - if worker_id_int not in self.tokenizer_mapping: - socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False) - self.tokenizer_mapping[worker_id_int] = socket - logger.info( - f"Main Tokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}" - ) - else: - logger.info( - f"ZMQ socket for worker {worker_id} already exists, skipping creation" - ) - - async def _distribute_result_to_workers(self, recv_obj): - """Distribute result to corresponding workers based on rid""" - - worker_ids = get_workerids_from_rids(recv_obj.rids) - if len(worker_ids) == 0: self._result_dispatcher(recv_obj) - return - - if not hasattr(self, "tokenizer_mapping"): - self.tokenizer_mapping = {} - - # Create ZMQ context if needed - if not hasattr(self, "_zmq_context"): - self._zmq_context = zmq.Context() - - # Distribute result to each worker - for i, worker_id in enumerate(worker_ids): - if worker_id not in self.tokenizer_mapping: - if isinstance(recv_obj, MultiTokenizerRegisterReq): - self.init_tokenizer_mapping(recv_obj) - else: - logger.error( - f"Worker {worker_id} not registered and not found in tokenizer mapping . " - "Please ensure the worker is registered correctly." - ) - continue - else: - if isinstance(recv_obj, MultiTokenizerRegisterReq): - continue - - if not isinstance( - recv_obj, - ( - BatchStrOut, - BatchEmbeddingOut, - BatchTokenIDOut, - BatchMultimodalOut, - ), - ): - # Send to worker - self.tokenizer_mapping[worker_id].send_pyobj(recv_obj) - else: - if isinstance(recv_obj, BatchTokenIDOut): - new_recv_obj = BatchTokenIDOut( - [recv_obj.rids[i]], - ( - [recv_obj.finished_reasons[i]] - if len(recv_obj.finished_reasons) > i - else None - ), - ( - [recv_obj.decoded_texts[i]] - if len(recv_obj.decoded_texts) > i - else None - ), - ( - [recv_obj.decode_ids[i]] - if len(recv_obj.decode_ids) > i - else None - ), - ( - [recv_obj.read_offsets[i]] - if len(recv_obj.read_offsets) > i - else None - ), - ( - [recv_obj.output_ids[i]] - if recv_obj.output_ids and len(recv_obj.output_ids) > i - else None - ), - ( - [recv_obj.skip_special_tokens[i]] - if len(recv_obj.skip_special_tokens) > i - else None - ), - ( - [recv_obj.spaces_between_special_tokens[i]] - if len(recv_obj.spaces_between_special_tokens) > i - else None - ), - ( - [recv_obj.no_stop_trim[i]] - if len(recv_obj.no_stop_trim) > i - else None - ), - ( - [recv_obj.prompt_tokens[i]] - if len(recv_obj.prompt_tokens) > i - else None - ), - ( - [recv_obj.completion_tokens[i]] - if len(recv_obj.completion_tokens) > i - else None - ), - ( - [recv_obj.cached_tokens[i]] - if len(recv_obj.cached_tokens) > i - else None - ), - ( - [recv_obj.spec_verify_ct[i]] - if len(recv_obj.spec_verify_ct) > i - else None - ), - ( - [recv_obj.input_token_logprobs_val[i]] - if recv_obj.input_token_logprobs_val - else None - ), - ( - [recv_obj.input_token_logprobs_idx[i]] - if recv_obj.input_token_logprobs_idx - else None - ), - ( - [recv_obj.output_token_logprobs_val[i]] - if recv_obj.output_token_logprobs_val - else None - ), - ( - [recv_obj.output_token_logprobs_idx[i]] - if recv_obj.output_token_logprobs_idx - else None - ), - ( - [recv_obj.input_top_logprobs_val[i]] - if recv_obj.input_top_logprobs_val - else None - ), - ( - [recv_obj.input_top_logprobs_idx[i]] - if recv_obj.input_top_logprobs_idx - else None - ), - ( - [recv_obj.output_top_logprobs_val[i]] - if recv_obj.output_top_logprobs_val - else None - ), - ( - [recv_obj.output_top_logprobs_idx[i]] - if recv_obj.output_top_logprobs_idx - else None - ), - ( - [recv_obj.input_token_ids_logprobs_val[i]] - if recv_obj.input_token_ids_logprobs_val - else None - ), - ( - [recv_obj.input_token_ids_logprobs_idx[i]] - if recv_obj.input_token_ids_logprobs_idx - else None - ), - ( - [recv_obj.output_token_ids_logprobs_val[i]] - if recv_obj.output_token_ids_logprobs_val - else None - ), - ( - [recv_obj.output_token_ids_logprobs_idx[i]] - if recv_obj.output_token_ids_logprobs_idx - else None - ), - ( - [recv_obj.output_hidden_states[i]] - if recv_obj.output_hidden_states - else None - ), - ) - elif isinstance(recv_obj, BatchEmbeddingOut): - new_recv_obj = BatchEmbeddingOut( - [recv_obj.rids[i]], - ( - [recv_obj.finished_reasons[i]] - if len(recv_obj.finished_reasons) > i - else None - ), - ( - [recv_obj.embeddings[i]] - if len(recv_obj.embeddings) > i - else None - ), - ( - [recv_obj.prompt_tokens[i]] - if len(recv_obj.prompt_tokens) > i - else None - ), - ( - [recv_obj.cached_tokens[i]] - if len(recv_obj.cached_tokens) > i - else None - ), - ) - elif isinstance(recv_obj, BatchStrOut): - new_recv_obj = BatchStrOut( - [recv_obj.rids[i]], - ( - [recv_obj.finished_reasons[i]] - if len(recv_obj.finished_reasons) > i - else None - ), - ( - [recv_obj.output_strs[i]] - if len(recv_obj.output_strs) > i - else None - ), - ( - [recv_obj.output_ids[i]] - if recv_obj.output_ids and len(recv_obj.output_ids) > i - else None - ), - ( - [recv_obj.prompt_tokens[i]] - if len(recv_obj.prompt_tokens) > i - else None - ), - ( - [recv_obj.completion_tokens[i]] - if len(recv_obj.completion_tokens) > i - else None - ), - ( - [recv_obj.cached_tokens[i]] - if len(recv_obj.cached_tokens) > i - else None - ), - ( - [recv_obj.spec_verify_ct[i]] - if len(recv_obj.spec_verify_ct) > i - else None - ), - ( - [recv_obj.input_token_logprobs_val[i]] - if recv_obj.input_token_logprobs_val - else None - ), - ( - [recv_obj.input_token_logprobs_idx[i]] - if recv_obj.input_token_logprobs_idx - else None - ), - ( - [recv_obj.output_token_logprobs_val[i]] - if recv_obj.output_token_logprobs_val - else None - ), - ( - [recv_obj.output_token_logprobs_idx[i]] - if recv_obj.output_token_logprobs_idx - else None - ), - ( - [recv_obj.input_top_logprobs_val[i]] - if recv_obj.input_top_logprobs_val - else None - ), - ( - [recv_obj.input_top_logprobs_idx[i]] - if recv_obj.input_top_logprobs_idx - else None - ), - ( - [recv_obj.output_top_logprobs_val[i]] - if recv_obj.output_top_logprobs_val - else None - ), - ( - [recv_obj.output_top_logprobs_idx[i]] - if recv_obj.output_top_logprobs_idx - else None - ), - ( - [recv_obj.input_token_ids_logprobs_val[i]] - if recv_obj.input_token_ids_logprobs_val - else None - ), - ( - [recv_obj.input_token_ids_logprobs_idx[i]] - if recv_obj.input_token_ids_logprobs_idx - else None - ), - ( - [recv_obj.output_token_ids_logprobs_val[i]] - if recv_obj.output_token_ids_logprobs_val - else None - ), - ( - [recv_obj.output_token_ids_logprobs_idx[i]] - if recv_obj.output_token_ids_logprobs_idx - else None - ), - ( - [recv_obj.output_hidden_states[i]] - if recv_obj.output_hidden_states - else None - ), - ) - elif isinstance(recv_obj, BatchMultimodalOut): - new_recv_obj = BatchMultimodalOut( - [recv_obj.rids[i]], - ( - [recv_obj.finished_reasons[i]] - if len(recv_obj.finished_reasons) > i - else None - ), - ([recv_obj.outputs[i]] if len(recv_obj.outputs) > i else None), - ( - [recv_obj.prompt_tokens[i]] - if len(recv_obj.prompt_tokens) > i - else None - ), - ( - [recv_obj.completion_tokens[i]] - if len(recv_obj.completion_tokens) > i - else None - ), - ( - [recv_obj.cached_tokens[i]] - if len(recv_obj.cached_tokens) > i - else None - ), - ) - try: - self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj) - except zmq.ZMQError as e: - raise RuntimeError( - f"Failed to send result to worker {worker_id}: {e}" - ) from e - - def register_to_main_tokenizer_manager(self): - """Register this worker to the main TokenizerManager""" - req = MultiTokenizerRegisterReq() - req.rids = [f"{self.worker_id}_registertokenizer"] - req.ipc_name = self.tokenizer_ipc_name - self.send_to_scheduler.send_pyobj(req) - time.sleep(5) + self.last_receive_tstamp = time.time() def _handle_batch_output( self, @@ -1946,12 +1524,10 @@ class TokenizerManager: f"Received output for {rid=} but the state was deleted in TokenizerManager." ) continue - originRid = rid - if self.server_args.tokenizer_worker_num > 1: - originRid = get_origin_rid(rid) + # Build meta_info and return value meta_info = { - "id": originRid, + "id": rid, "finish_reason": recv_obj.finished_reasons[i], "prompt_tokens": recv_obj.prompt_tokens[i], } @@ -2252,9 +1828,6 @@ class TokenizerManager: if is_health_check_generate_req(recv_obj): return state = self.rid_to_state[recv_obj.rid] - rid = recv_obj.rid - if self.server_args.tokenizer_worker_num > 1: - rid = get_origin_rid(rid) state.finished = True if recv_obj.finished_reason: out = { @@ -2267,7 +1840,7 @@ class TokenizerManager: out = { "text": "", "meta_info": { - "id": rid, + "id": recv_obj.rid, "finish_reason": { "type": "abort", "message": "Abort before prefill", @@ -2456,7 +2029,6 @@ class _Communicator(Generic[T]): self._ready_queue: Deque[asyncio.Future] = deque() async def __call__(self, obj): - global _global_tokenizer_worker_num ready_event = asyncio.Event() if self._result_event is not None or len(self._ready_queue) > 0: self._ready_queue.append(ready_event) @@ -2465,14 +2037,6 @@ class _Communicator(Generic[T]): assert self._result_values is None if obj: - if _global_tokenizer_worker_num > 1: - if obj.rids is None: - obj.rids = f"{os.getpid()}_{uuid.uuid4().hex}_Communicator" - else: - if isinstance(obj.rids, str): - obj.rids = f"{os.getpid()}_{obj.rids}" - elif isinstance(obj.rids, list): - obj.rids = [f"{os.getpid()}_{rid}" for rid in obj.rids] self._sender.send_pyobj(obj) self._result_event = asyncio.Event() @@ -2487,19 +2051,6 @@ class _Communicator(Generic[T]): return result_values def handle_recv(self, recv_obj: T): - global _global_tokenizer_worker_num - if _global_tokenizer_worker_num > 1: - # If rids is a string and not empty, remove the prefix - if ( - hasattr(recv_obj, "rids") - and isinstance(recv_obj.rids, str) - and recv_obj.rids - ): - recv_obj.rids = get_origin_rid(recv_obj.rids) - # If rids is a list, remove prefix from each element - elif hasattr(recv_obj, "rids") and isinstance(recv_obj.rids, list): - recv_obj.rids = [get_origin_rid(rid) for rid in recv_obj.rids] - self._result_values.append(recv_obj) if len(self._result_values) == self._fan_out: self._result_event.set() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e6d2f9c57..7bfd443bf 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -51,7 +51,6 @@ class ServerArgs: model_path: str tokenizer_path: Optional[str] = None tokenizer_mode: str = "auto" - tokenizer_worker_num: int = 1 skip_tokenizer_init: bool = False load_format: str = "auto" model_loader_extra_config: str = "{}" @@ -732,12 +731,6 @@ class ServerArgs: default=ServerArgs.tokenizer_path, help="The path of the tokenizer.", ) - parser.add_argument( - "--tokenizer-worker-num", - type=int, - default=ServerArgs.tokenizer_worker_num, - help="The worker num of the tokenizer manager.", - ) parser.add_argument( "--tokenizer-mode", type=str, @@ -2096,9 +2089,6 @@ class ServerArgs: self.chunked_prefill_size % self.page_size == 0 ), "chunked_prefill_size must be divisible by page_size" - # Check multi tokenizer - assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1" - def check_lora_server_args(self): assert ( self.max_loras_per_batch > 0 @@ -2264,9 +2254,6 @@ class PortArgs: # The ipc filename for Scheduler to send metrics metrics_ipc_name: str - # The ipc filename for Tokenizer and worker tokenizer - tokenizer_worker_ipc_name: Optional[str] - @staticmethod def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": if server_args.nccl_port is None: @@ -2290,7 +2277,6 @@ class PortArgs: nccl_port=nccl_port, rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", - tokenizer_worker_ipc_name=None, ) else: # DP attention. Use TCP + port to handle both single-node and multi-node. @@ -2324,7 +2310,6 @@ class PortArgs: nccl_port=nccl_port, rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}", metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}", - tokenizer_worker_ipc_name=None, ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 3a6fd7b42..edf441945 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2754,20 +2754,6 @@ def lru_cache_frozenset(maxsize=128): return decorator -def get_workerids_from_rids(rids): - if isinstance(rids, list): - worker_ids = [int(rid.split("_")[0]) for rid in rids] - elif isinstance(rids, str): - worker_ids = [int(rids.split("_")[0])] - else: - worker_ids = [] - return worker_ids - - -def get_origin_rid(rid): - return rid.split("_", 1)[1] if "_" in rid else rid - - def apply_module_patch(target_module, target_function, wrappers): original_module, original_function = parse_module_path( target_module, target_function, False diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index c4546519a..aecea4498 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -78,7 +78,6 @@ suites = { TestFile("test_mla_int8_deepseek_v3.py", 429), TestFile("test_mla_flashinfer.py", 302), TestFile("test_mla_fp8.py", 93), - TestFile("test_multi_tokenizer.py", 200), TestFile("test_no_chunked_prefill.py", 108), TestFile("test_no_overlap_scheduler.py", 234), TestFile("test_penalty.py", 41), diff --git a/test/srt/test_multi_tokenizer.py b/test/srt/test_multi_tokenizer.py deleted file mode 100644 index 53409d473..000000000 --- a/test/srt/test_multi_tokenizer.py +++ /dev/null @@ -1,100 +0,0 @@ -import inspect -import unittest -from dataclasses import fields, is_dataclass -from types import SimpleNamespace - -import sglang.srt.managers.io_struct as io_struct -from sglang.srt.utils import kill_process_tree -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - auto_config_device, - get_benchmark_args, - is_in_ci, - popen_launch_server, - run_benchmark, - write_github_step_summary, -) - - -class TestMultiTokenizer(CustomTestCase): - # from test_hicache.py - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--tokenizer-worker-num", - 8, - "--mem-fraction-static", - 0.7, - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.65) - - def test_all_io_struct(self): - print("check all req types in io_struct.py") - result = [] - for name, obj in inspect.getmembers(io_struct): - if inspect.isclass(obj) and is_dataclass(obj): - field_names = [f.name for f in fields(obj)] - if "rids" in field_names or "rid" in field_names: - continue - result.append(name) - print(f"WARNING:Some Request types in io_struct.py have no rids: {result}") - print( - "If a special request type can't work, check the rids field which is needed for multi-tokenizer." - ) - - def test_multi_tokenizer_ttft(self): - # from test_bench_serving.py run_bench_serving - args = get_benchmark_args( - base_url=self.base_url, - dataset_name="random", - dataset_path="", - tokenizer=None, - num_prompts=100, - random_input_len=4096, - random_output_len=2048, - sharegpt_context_len=None, - request_rate=1, - disable_stream=False, - disable_ignore_eos=False, - seed=0, - device=auto_config_device(), - lora_name=None, - ) - res = run_benchmark(args) - if is_in_ci(): - write_github_step_summary( - f"### test_multi_tokenizer_ttft\n" - f"median_e2e_latency_ms: {res['median_e2e_latency_ms']:.2f} ms\n" - ) - self.assertLess(res["median_e2e_latency_ms"], 11000) - self.assertLess(res["median_ttft_ms"], 86) - self.assertLess(res["median_itl_ms"], 10) - - -if __name__ == "__main__": - unittest.main()