[1/2] Refactor multi-tokenizer manager (#10074)

This commit is contained in:
Liangsheng Yin
2025-09-07 19:13:34 +08:00
committed by GitHub
parent 067246830d
commit e719bb0e84
6 changed files with 421 additions and 485 deletions

View File

@@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True)
def _init_tokenizer_manager(
server_args: ServerArgs, port_args: PortArgs
) -> TokenizerManager:
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)
return tokenizer_manager, template_manager
def _launch_subprocesses(
server_args: ServerArgs, port_args: Optional[PortArgs] = None
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
@@ -816,23 +834,15 @@ def _launch_subprocesses(
),
)
detoken_proc.start()
# Init tokenizer manager first, as the bootstrap server is initialized here
if server_args.tokenizer_worker_num > 1:
# Launch multi-tokenizer router
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
# Initialize templates
template_manager = None
else:
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
tokenizer_manager, template_manager = _init_tokenizer_manager(
server_args, port_args
)
# Wait for the model to finish loading
@@ -856,5 +866,7 @@ def _launch_subprocesses(
# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, template_manager, scheduler_info

View File

@@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.multi_tokenizer_mixin import (
MultiTokenizerManager,
deserialize_data,
get_main_process_id,
read_from_shared_memory,
write_data_for_multi_tokenizer,
@@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
# Function to set up all middlewares for multi-tokenizer compatibility
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
"""Setup all middlewares for both single and multi-process modes"""
worker_pid = os.getpid()
if api_key:
add_api_key_middleware(app, api_key)
logger.info(f"Worker {worker_pid} added API key middleware")
if enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
logger.info(f"Worker {worker_pid} added prometheus middleware")
async def init_multi_tokenizer() -> ServerArgs:
"""Read args information from shm and init tokenizer manager for current process"""
pid = os.getpid()
@@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs:
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
# Read configuration from shared memory
port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
port_args, server_args = deserialize_data(port_args_data, server_args_data)
scheduler_info = scheduler_info_data
port_args, server_args, scheduler_info = read_from_shared_memory(
f"multi_tokenizer_args_{main_pid}"
)
server_args: ServerArgs
# API key authentication is not supported in multi-tokenizer mode
assert (
server_args.api_key is None
), "API key is not supported in multi-tokenizer mode"
port_args.tokenizer_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args = getattr(fast_api_app, "server_args", None)
if server_args is None:
if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
# Initialize multi-tokenizer support for worker processes
fast_api_app.server_args = await init_multi_tokenizer()
setup_middlewares(
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
)
fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
# only metrics middleware is supported in multi-tokenizer mode
worker_pid = os.getpid()
if fast_api_app.server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
logger.info(f"Worker {worker_pid} added prometheus middleware")
fast_api_app.warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
@@ -1187,12 +1179,10 @@ def launch_server(
)
if server_args.tokenizer_worker_num > 1:
port_args_shm, server_args_shm, scheduler_info_shm = (
write_data_for_multi_tokenizer(
port_args,
server_args,
scheduler_info,
)
multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
port_args,
server_args,
scheduler_info,
)
else:
# Add api key authorization
@@ -1239,6 +1229,7 @@ def launch_server(
workers=server_args.tokenizer_worker_num,
)
else:
app.is_single_tokenizer_mode = True
uvicorn.run(
app,
host=server_args.host,
@@ -1249,10 +1240,8 @@ def launch_server(
)
finally:
if server_args.tokenizer_worker_num > 1:
port_args_shm.unlink()
server_args_shm.unlink()
scheduler_info_shm.unlink()
_global_state.tokenizer_manager.clear_tokenizer_mapping()
multi_tokenizer_args_shm.unlink()
_global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
else:
warmup_thread.join()