Support Multi Process Tokenizer Manager (#6555)

Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: lw9527 <952799980@qq.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
This commit is contained in:
ybyang
2025-08-08 16:45:50 +08:00
committed by GitHub
parent 6ee6619b7a
commit 7490e3f67d
9 changed files with 1133 additions and 73 deletions

View File

@@ -18,14 +18,18 @@ This file implements HTTP APIs for the inference engine via fastapi.
"""
import asyncio
import ctypes
import dataclasses
import json
import logging
import multiprocessing as multiprocessing
import os
import sys
import tempfile
import threading
import time
from http import HTTPStatus
from multiprocessing import Lock, Manager, Value, shared_memory
from typing import AsyncIterator, Callable, Dict, Optional
# Fix a bug of Python threading
@@ -94,7 +98,7 @@ from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
add_api_key_middleware,
add_prometheus_middleware,
@@ -129,8 +133,165 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
def serialize_port_args(port_args: PortArgs) -> dict:
    """Snapshot *port_args* as a plain dict so it can be shared via shm.

    Only the IPC endpoint names and ports that worker processes need are
    copied; the result is JSON-serializable.
    """
    shared_fields = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
        "nccl_port",
        "rpc_ipc_name",
        "metrics_ipc_name",
        "tokenizer_worker_ipc_name",
    )
    return {field: getattr(port_args, field) for field in shared_fields}
def deserialize_port_args(data: dict) -> PortArgs:
    """Rebuild a PortArgs instance from its shared-memory dict snapshot."""
    # The snapshot keys match the PortArgs constructor parameters one-to-one.
    return PortArgs(**data)
def serialize_server_args(server_args: ServerArgs) -> dict:
    """Snapshot the ServerArgs dataclass as a plain (JSON-friendly) dict."""
    # dataclasses.asdict recurses into nested dataclasses, giving pure
    # built-in containers suitable for json.dumps.
    snapshot = dataclasses.asdict(server_args)
    return snapshot
def deserialize_server_args(data: dict) -> ServerArgs:
    """Rebuild a ServerArgs instance from its shared-memory dict snapshot."""
    # Keys in the snapshot mirror the ServerArgs constructor parameters.
    return ServerArgs(**data)
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
    """Return *scheduler_info* unchanged.

    scheduler_info is already a plain dict, so no conversion is needed;
    this exists only for symmetry with the other serialize_* helpers.
    """
    return scheduler_info
def deserialize_scheduler_info(data: dict) -> Dict:
    """Return *data* unchanged.

    The shared representation is already the dict the callers expect; this
    exists only for symmetry with the other deserialize_* helpers.
    """
    return data
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
    """Serialize *data* as JSON and publish it in a named shared-memory segment.

    The segment is created if absent and recreated (larger) if the existing
    one is too small.  Any spare capacity after the payload is filled with
    ASCII spaces: without this, a reader that decodes the whole buffer would
    see stale bytes from a previous, longer payload (or non-JSON zero fill on
    platforms that round the segment size up) and fail to parse the document.

    Args:
        data: JSON-serializable dictionary to publish.
        name: OS-level name of the shared-memory segment.

    Returns:
        The open SharedMemory handle; the caller owns close()/unlink().
    """
    serialized = json.dumps(data).encode("utf-8")
    size = len(serialized)
    try:
        # Try to reuse an existing segment with this name.
        shm = shared_memory.SharedMemory(name=name)
        if shm.size < size:
            # Too small for the new payload: drop it and allocate afresh.
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    except FileNotFoundError:
        # First writer: create the segment.
        shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    shm.buf[:size] = serialized
    if shm.size > size:
        # Blank out leftover capacity with JSON-neutral whitespace so the
        # reader's json.loads tolerates the full buffer.
        shm.buf[size:] = b" " * (shm.size - size)
    return shm
def read_from_shared_memory(name: str) -> dict:
    """Load and JSON-decode the payload stored in shared-memory segment *name*.

    The buffer is truncated at the first NUL byte before decoding: a freshly
    created segment can be larger than the payload (size may be rounded up by
    the platform) and the zero-filled excess is not valid JSON trailing data.
    json.dumps never emits a raw NUL, so truncation cannot cut the payload.

    Args:
        name: OS-level name of the shared-memory segment.

    Returns:
        The deserialized dictionary.

    Raises:
        FileNotFoundError: if no segment with that name exists.
    """
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        raise FileNotFoundError(f"Shared memory {name} not found")
    try:
        raw = bytes(shm.buf).split(b"\x00", 1)[0]
        return json.loads(raw.decode("utf-8"))
    finally:
        # Always release the local mapping, even if decoding fails.
        shm.close()
def get_main_process_id() -> Optional[int]:
    """Return the PID of the parent (main) process, or None in the main process.

    Uses the public multiprocessing.parent_process() API (3.8+) instead of the
    private Process._parent_pid attribute, preserving the original behavior:
    in a multiprocessing-spawned worker this is the launcher's PID; in the
    main process itself there is no parent, so None is returned.
    """
    parent = multiprocessing.parent_process()
    return parent.pid if parent is not None else None
def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Write args information to share memory for multi-tokenizer"""
    # Segments are named after this (main) process's PID; uvicorn workers
    # find them again via their parent PID.
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    # (payload, segment name) for each bundle the workers will need.
    bundles = (
        (serialize_port_args(port_args), f"port_args_{current_pid}"),
        (serialize_server_args(server_args), f"server_args_{current_pid}"),
        (serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"),
    )
    handles = []
    for payload, shm_name in bundles:
        shm = write_to_shared_memory(payload, shm_name)
        # Close the local mapping; the segment itself stays alive until the
        # caller unlink()s the returned handles at shutdown.
        shm.close()
        handles.append(shm)
    port_args_shm, server_args_shm, scheduler_info_shm = handles
    return port_args_shm, server_args_shm, scheduler_info_shm
def init_multi_tokenizer() -> ServerArgs:
    """Read args information from shm and init tokenizer manager for current process"""
    # Runs inside a uvicorn worker: the shared-memory segments were published
    # by the main process under names keyed by ITS pid (this worker's parent).
    pid = os.getpid()
    main_pid = get_main_process_id()
    logger.info(f"current worker_id: {pid}, main processID: {main_pid}")

    # Read port_args, server_args, and scheduler_info from shared memory
    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
    port_args = deserialize_port_args(port_args_data)
    server_args = deserialize_server_args(server_args_data)
    scheduler_info = deserialize_scheduler_info(scheduler_info_data)

    # Each worker needs its own tokenizer IPC endpoint; a throwaway temp file
    # supplies a unique filesystem path for the ipc:// socket.
    # NOTE(review): NamedTemporaryFile(delete=False) leaves the file behind —
    # presumably acceptable for process lifetime; confirm cleanup elsewhere.
    port_args.tokenizer_ipc_name = (
        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
    )

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args, False)
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )

    # register multi tokenizer
    tokenizer_manager.register_to_main_tokenizer_manager()

    # max_req_input_len is computed by the scheduler, not serialized into
    # TokenizerManager's constructor args, so copy it over explicitly.
    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
            template_manager=template_manager,
            scheduler_info=scheduler_info,
        )
    )
    return server_args
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args = getattr(fast_api_app, "server_args", None)
if server_args is None:
# for multi-tokenizer
fast_api_app.server_args = init_multi_tokenizer()
fast_api_app.warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
fast_api_app.server_args,
None, # pipe_finish_writer not needed in worker
None, # launch_callback not needed in worker
),
)
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
@@ -191,7 +352,15 @@ async def lifespan(fast_api_app: FastAPI):
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
yield
try:
yield
finally:
if server_args.tokenizer_worker_num > 1:
pid = os.getpid()
logger.info(f"uvicorn worker {pid} ending...")
warmup_thread.join()
logger.info(f"uvicorn {pid} ended")
# Fast API
@@ -208,6 +377,30 @@ app.add_middleware(
)
# Function to setup all middlewares for multi-process compatibility
def setup_middlewares():
    """Setup all middlewares for both single and multi-process modes"""
    # In multi-process mode the main process cannot mutate worker apps, so
    # settings arrive via environment variables instead of ServerArgs.
    worker_pid = os.getpid()

    api_key = os.environ.get("SGLANG_API_KEY", "")
    if api_key:
        add_api_key_middleware(app, api_key)
        logger.info(f"Worker {worker_pid} added API key middleware")

    if get_bool_env_var("SGLANG_ENABLE_METRICS", "false"):
        add_prometheus_middleware(app)
        enable_func_timer()
        logger.info(f"Worker {worker_pid} added prometheus middleware")
# Invoked at import time so that every uvicorn worker process (which imports
# this module rather than running launch_server) installs the middlewares.
setup_middlewares()
@app.exception_handler(HTTPException)
async def validation_exception_handler(request: Request, exc: HTTPException):
"""Enrich HTTP exception with status code and other details"""
@@ -993,9 +1186,19 @@ def launch_server(
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
# NOTE(review): the flattened diff duplicated the old unconditional
# _launch_subprocesses call ahead of the new if/else; reconstructed here as
# the single conditional launch so subprocesses are started exactly once.
if server_args.tokenizer_worker_num > 1:
    # Multi-tokenizer mode: pre-build PortArgs with a dedicated IPC endpoint
    # that worker tokenizers use to register with the main tokenizer manager.
    port_args = PortArgs.init_new(server_args)
    port_args.tokenizer_worker_ipc_name = (
        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
    )
    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
        server_args=server_args, port_args=port_args
    )
else:
    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
        server_args=server_args,
    )
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
@@ -1004,42 +1207,83 @@ def launch_server(
)
)
# NOTE(review): this span interleaved removed (pre-change) and added lines
# from the flattened diff; reconstructed as the post-change if/else only.
if server_args.tokenizer_worker_num > 1:
    # Multi-worker mode: each uvicorn worker installs its own middlewares in
    # setup_middlewares() at import time, so pass settings through the
    # environment instead of touching `app` here.
    if server_args.api_key:
        os.environ["SGLANG_API_KEY"] = server_args.api_key
        logger.info("Main process set SGLANG_API_KEY")
    if server_args.enable_metrics:
        os.environ["SGLANG_ENABLE_METRICS"] = "true"
        logger.info("Main process set SGLANG_ENABLE_METRICS=true")
    # Publish port_args/server_args/scheduler_info for the workers; the
    # returned handles are unlinked in the finally block after serving.
    port_args_shm, server_args_shm, scheduler_info_shm = (
        write_data_for_multi_tokenizer(
            port_args,
            server_args,
            scheduler_info,
        )
    )
else:
    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Add prometheus middleware
    if server_args.enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()

    # Send a warmup request - we will create the thread launch it
    # in the lifespan after all other warmups have fired.
    warmup_thread = threading.Thread(
        target=_wait_and_warmup,
        args=(
            server_args,
            pipe_finish_writer,
            launch_callback,
        ),
    )
    app.warmup_thread = warmup_thread
# NOTE(review): the flattened diff kept the old flat uvicorn.run and the old
# unconditional warmup_thread.join() beside the new conditional versions;
# reconstructed as the post-change try/finally only.
try:
    # Update logging configs
    set_uvicorn_logging_configs()
    app.server_args = server_args
    # Listen for HTTP requests
    if server_args.tokenizer_worker_num > 1:
        from uvicorn.config import LOGGING_CONFIG

        # Worker processes re-import this module, so route its logs through
        # uvicorn's default handler instead of double-configuring logging.
        LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
            "handlers": ["default"],
            "level": "INFO",
            "propagate": False,
        }
        # Multi-worker mode requires the app as an import string so each
        # worker process can import it independently.
        uvicorn.run(
            "sglang.srt.entrypoints.http_server:app",
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
            workers=server_args.tokenizer_worker_num,
        )
    else:
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
finally:
    if server_args.tokenizer_worker_num > 1:
        # Main process owns the shm segments; remove them after serving ends.
        # (Workers join their own warmup threads in their lifespan handlers.)
        port_args_shm.unlink()
        server_args_shm.unlink()
        scheduler_info_shm.unlink()
    else:
        warmup_thread.join()
def _execute_server_warmup(