Support Multi Process Tokenizer Manager (#6555)
Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: lw9527 <952799980@qq.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
This commit is contained in:
@@ -18,14 +18,18 @@ This file implements HTTP APIs for the inference engine via fastapi.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ctypes
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from multiprocessing import Lock, Manager, Value, shared_memory
|
||||
from typing import AsyncIterator, Callable, Dict, Optional
|
||||
|
||||
# Fix a bug of Python threading
|
||||
@@ -94,7 +98,7 @@ from sglang.srt.managers.template_manager import TemplateManager
|
||||
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
|
||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
add_api_key_middleware,
|
||||
add_prometheus_middleware,
|
||||
@@ -129,8 +133,165 @@ def set_global_state(global_state: _GlobalState):
|
||||
_global_state = global_state
|
||||
|
||||
|
||||
def serialize_port_args(port_args: PortArgs) -> dict:
    """Pack the PortArgs fields needed by tokenizer workers into a plain dict.

    The result is JSON-serializable so it can be published via shared memory.
    """
    # Only the IPC endpoints / ports that worker processes must know about.
    shared_fields = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
        "nccl_port",
        "rpc_ipc_name",
        "metrics_ipc_name",
        "tokenizer_worker_ipc_name",
    )
    return {name: getattr(port_args, name) for name in shared_fields}
|
||||
|
||||
|
||||
def deserialize_port_args(data: dict) -> PortArgs:
    """Rebuild a PortArgs instance from a dict read out of shared memory."""
    kwargs = dict(data)
    return PortArgs(**kwargs)
|
||||
|
||||
|
||||
def serialize_server_args(server_args: ServerArgs) -> dict:
    """Flatten the ServerArgs dataclass into a plain dict for shared memory."""
    # dataclasses.asdict recurses into nested dataclasses, giving a
    # JSON-friendly structure.
    return dataclasses.asdict(server_args)
|
||||
|
||||
|
||||
def deserialize_server_args(data: dict) -> ServerArgs:
    """Rebuild a ServerArgs instance from a dict read out of shared memory."""
    kwargs = dict(data)
    return ServerArgs(**kwargs)
|
||||
|
||||
|
||||
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
    """Return *scheduler_info* as-is; it is already a plain, shareable dict.

    Kept as a function to mirror the other (de)serialize helpers.
    """
    return scheduler_info
|
||||
|
||||
|
||||
def deserialize_scheduler_info(data: dict) -> Dict:
    """Return the shared dict as-is; no conversion is needed.

    Kept as a function to mirror the other (de)serialize helpers.
    """
    return data
|
||||
|
||||
|
||||
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
    """Write *data* (a JSON-serializable dict) to a named shared-memory segment.

    An existing segment with the same name is reused when it is large enough;
    otherwise it is unlinked and recreated. Any bytes past the payload are
    zero-filled so readers can strip the padding — the OS may round the
    segment size up to a whole page, and a reused segment may contain stale
    bytes from a previous, larger payload.

    Args:
        data: JSON-serializable payload.
        name: Shared-memory segment name.

    Returns:
        The open SharedMemory handle; the caller is responsible for
        closing and (eventually) unlinking it.
    """
    serialized = json.dumps(data).encode("utf-8")
    size = len(serialized)
    try:
        # Try to open existing shared memory
        shm = shared_memory.SharedMemory(name=name)
        # If size is insufficient, close and recreate
        if shm.size < size:
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    except FileNotFoundError:
        # If not present, create new shared memory
        shm = shared_memory.SharedMemory(create=True, size=size, name=name)

    shm.buf[:size] = serialized
    # Zero-fill the tail: stale bytes from a previous (larger) payload or
    # page-size rounding would otherwise corrupt a subsequent JSON read.
    if shm.size > size:
        shm.buf[size:] = b"\x00" * (shm.size - size)
    return shm
|
||||
|
||||
|
||||
def read_from_shared_memory(name: str) -> dict:
    """Read a JSON payload from the named shared-memory segment.

    Trailing NUL bytes are stripped before decoding: the OS may round the
    segment size up to a whole page, and writers zero-fill unused space.
    JSON text never contains raw NUL bytes, so stripping them is safe.

    Args:
        name: Shared-memory segment name.

    Returns:
        The decoded dict.

    Raises:
        FileNotFoundError: if no segment with *name* exists.
    """
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        raise FileNotFoundError(f"Shared memory {name} not found")
    try:
        payload = bytes(shm.buf).rstrip(b"\x00")
        return json.loads(payload.decode("utf-8"))
    finally:
        # Always release our mapping, even if the payload is malformed.
        shm.close()
|
||||
|
||||
|
||||
def get_main_process_id() -> Optional[int]:
    """Get the main process ID"""
    # NOTE(review): reads the private ``_parent_pid`` attribute of the current
    # multiprocessing.Process. In the main process itself this is ``None``.
    # Unlike os.getppid(), the value is fixed at process creation and does not
    # change if the parent dies — confirm that this is the intended semantics
    # before swapping in the public API.
    return multiprocessing.current_process()._parent_pid
|
||||
|
||||
|
||||
def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Publish args to shared memory for multi-tokenizer worker processes.

    Segments are named ``<kind>_<pid>`` after the current (writer) process so
    workers can locate them via their parent PID. The local mappings are
    closed immediately, but the handles are returned so the caller can
    unlink the segments at shutdown.
    """
    # get main process ID
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    # One (prefix, payload) pair per shared segment, written in a fixed order.
    payloads = (
        ("port_args", serialize_port_args(port_args)),
        ("server_args", serialize_server_args(server_args)),
        ("scheduler_info", serialize_scheduler_info(scheduler_info)),
    )
    segments = []
    for prefix, payload in payloads:
        shm = write_to_shared_memory(payload, f"{prefix}_{current_pid}")
        # Drop our mapping right away; the segment itself lives until unlink.
        shm.close()
        segments.append(shm)

    port_args_shm, server_args_shm, scheduler_info_shm = segments
    return port_args_shm, server_args_shm, scheduler_info_shm
|
||||
|
||||
|
||||
def init_multi_tokenizer() -> ServerArgs:
    """Read args information from shm and init tokenizer manager for current process

    Runs inside each uvicorn worker process: it reconstructs the PortArgs /
    ServerArgs / scheduler_info that the main process published to shared
    memory, builds this worker's own TokenizerManager, and registers it with
    the main tokenizer manager.

    Returns:
        The deserialized ServerArgs for this server instance.
    """
    pid = os.getpid()
    main_pid = get_main_process_id()
    logger.info(f"current worker_id: {pid}, main processID: {main_pid}")

    # Read port_args, server_args, and scheduler_info from shared memory
    # (the writer names segments after its own PID, which is our parent).
    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
    port_args = deserialize_port_args(port_args_data)
    server_args = deserialize_server_args(server_args_data)
    scheduler_info = deserialize_scheduler_info(scheduler_info_data)

    # Each worker needs its own tokenizer IPC endpoint; a throwaway temp file
    # provides a unique filesystem path for the ZMQ ipc:// socket.
    # NOTE(review): the temp file is created with delete=False and never
    # removed here — confirm it is cleaned up elsewhere.
    port_args.tokenizer_ipc_name = (
        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
    )

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args, False)
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )
    # register multi tokenizer
    tokenizer_manager.register_to_main_tokenizer_manager()

    # Propagate the scheduler's request-length limit to this worker's manager.
    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
            template_manager=template_manager,
            scheduler_info=scheduler_info,
        )
    )
    return server_args
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(fast_api_app: FastAPI):
|
||||
server_args = getattr(fast_api_app, "server_args", None)
|
||||
if server_args is None:
|
||||
# for multi-tokenizer
|
||||
fast_api_app.server_args = init_multi_tokenizer()
|
||||
fast_api_app.warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
fast_api_app.server_args,
|
||||
None, # pipe_finish_writer not needed in worker
|
||||
None, # launch_callback not needed in worker
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize OpenAI serving handlers
|
||||
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
|
||||
_global_state.tokenizer_manager, _global_state.template_manager
|
||||
@@ -191,7 +352,15 @@ async def lifespan(fast_api_app: FastAPI):
|
||||
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
|
||||
if warmup_thread is not None:
|
||||
warmup_thread.start()
|
||||
yield
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
pid = os.getpid()
|
||||
logger.info(f"uvicorn worker {pid} ending...")
|
||||
warmup_thread.join()
|
||||
logger.info(f"uvicorn {pid} ended")
|
||||
|
||||
|
||||
# Fast API
|
||||
@@ -208,6 +377,30 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
# Function to setup all middlewares for multi-process compatibility
|
||||
def setup_middlewares():
    """Attach API-key and prometheus middlewares to the app.

    Works in both single- and multi-process modes: configuration is passed
    via environment variables so each uvicorn worker, which re-imports this
    module, can apply the same middlewares independently.
    """
    worker_pid = os.getpid()

    # API-key middleware, if a key was exported by the main process.
    api_key = os.environ.get("SGLANG_API_KEY", "")
    if api_key:
        add_api_key_middleware(app, api_key)
        logger.info(f"Worker {worker_pid} added API key middleware")

    # Prometheus middleware + function timers, gated by an env flag.
    if get_bool_env_var("SGLANG_ENABLE_METRICS", "false"):
        add_prometheus_middleware(app)
        enable_func_timer()
        logger.info(f"Worker {worker_pid} added prometheus middleware")
|
||||
|
||||
|
||||
# Call setup function at module level for multi-process compatibility
# (each uvicorn worker imports this module and must attach its own
# middlewares; configuration arrives via environment variables).
setup_middlewares()
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def validation_exception_handler(request: Request, exc: HTTPException):
|
||||
"""Enrich HTTP exception with status code and other details"""
|
||||
@@ -993,9 +1186,19 @@ def launch_server(
|
||||
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
|
||||
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
|
||||
"""
|
||||
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
|
||||
server_args=server_args
|
||||
)
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
port_args = PortArgs.init_new(server_args)
|
||||
port_args.tokenizer_worker_ipc_name = (
|
||||
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
|
||||
)
|
||||
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
|
||||
server_args=server_args, port_args=port_args
|
||||
)
|
||||
else:
|
||||
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
|
||||
server_args=server_args,
|
||||
)
|
||||
|
||||
set_global_state(
|
||||
_GlobalState(
|
||||
tokenizer_manager=tokenizer_manager,
|
||||
@@ -1004,42 +1207,83 @@ def launch_server(
|
||||
)
|
||||
)
|
||||
|
||||
# Add api key authorization
|
||||
if server_args.api_key:
|
||||
add_api_key_middleware(app, server_args.api_key)
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
# Set environment variables for middlewares in main process
|
||||
if server_args.api_key:
|
||||
os.environ["SGLANG_API_KEY"] = server_args.api_key
|
||||
logger.info("Main process set SGLANG_API_KEY")
|
||||
|
||||
# Add prometheus middleware
|
||||
if server_args.enable_metrics:
|
||||
add_prometheus_middleware(app)
|
||||
enable_func_timer()
|
||||
if server_args.enable_metrics:
|
||||
os.environ["SGLANG_ENABLE_METRICS"] = "true"
|
||||
logger.info("Main process set SGLANG_ENABLE_METRICS=true")
|
||||
|
||||
# Send a warmup request - we will create the thread launch it
|
||||
# in the lifespan after all other warmups have fired.
|
||||
warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
server_args,
|
||||
pipe_finish_writer,
|
||||
launch_callback,
|
||||
),
|
||||
)
|
||||
app.warmup_thread = warmup_thread
|
||||
port_args_shm, server_args_shm, scheduler_info_shm = (
|
||||
write_data_for_multi_tokenizer(
|
||||
port_args,
|
||||
server_args,
|
||||
scheduler_info,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Add api key authorization
|
||||
if server_args.api_key:
|
||||
add_api_key_middleware(app, server_args.api_key)
|
||||
|
||||
# Add prometheus middleware
|
||||
if server_args.enable_metrics:
|
||||
add_prometheus_middleware(app)
|
||||
enable_func_timer()
|
||||
|
||||
# Send a warmup request - we will create the thread launch it
|
||||
# in the lifespan after all other warmups have fired.
|
||||
warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
server_args,
|
||||
pipe_finish_writer,
|
||||
launch_callback,
|
||||
),
|
||||
)
|
||||
app.warmup_thread = warmup_thread
|
||||
|
||||
try:
|
||||
# Update logging configs
|
||||
set_uvicorn_logging_configs()
|
||||
app.server_args = server_args
|
||||
# Listen for HTTP requests
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
port=server_args.port,
|
||||
log_level=server_args.log_level_http or server_args.log_level,
|
||||
timeout_keep_alive=5,
|
||||
loop="uvloop",
|
||||
)
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
from uvicorn.config import LOGGING_CONFIG
|
||||
|
||||
LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
}
|
||||
uvicorn.run(
|
||||
"sglang.srt.entrypoints.http_server:app",
|
||||
host=server_args.host,
|
||||
port=server_args.port,
|
||||
log_level=server_args.log_level_http or server_args.log_level,
|
||||
timeout_keep_alive=5,
|
||||
loop="uvloop",
|
||||
workers=server_args.tokenizer_worker_num,
|
||||
)
|
||||
else:
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
port=server_args.port,
|
||||
log_level=server_args.log_level_http or server_args.log_level,
|
||||
timeout_keep_alive=5,
|
||||
loop="uvloop",
|
||||
)
|
||||
finally:
|
||||
warmup_thread.join()
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
port_args_shm.unlink()
|
||||
server_args_shm.unlink()
|
||||
scheduler_info_shm.unlink()
|
||||
else:
|
||||
warmup_thread.join()
|
||||
|
||||
|
||||
def _execute_server_warmup(
|
||||
|
||||
Reference in New Issue
Block a user