[1/2] Refactor multi-tokenizer manager (#10074)
This commit is contained in:
@@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def _init_tokenizer_manager(
    server_args: ServerArgs, port_args: PortArgs
) -> Tuple[TokenizerManager, TemplateManager]:
    """Build the tokenizer manager and the template manager that goes with it.

    Args:
        server_args: Server configuration; its ``model_path``,
            ``chat_template`` and ``completion_template`` fields are passed
            on to template initialization.
        port_args: Port configuration handed to ``TokenizerManager``.

    Returns:
        A ``(tokenizer_manager, template_manager)`` pair. The annotation used
        to declare a lone ``TokenizerManager`` even though two objects are
        returned and callers unpack both.
    """
    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)

    # Initialize templates
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )

    return tokenizer_manager, template_manager
|
||||
|
||||
|
||||
def _launch_subprocesses(
|
||||
server_args: ServerArgs, port_args: Optional[PortArgs] = None
|
||||
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
|
||||
@@ -816,23 +834,15 @@ def _launch_subprocesses(
|
||||
),
|
||||
)
|
||||
detoken_proc.start()
|
||||
|
||||
# Init tokenizer manager first, as the bootstrap server is initialized here
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
# Launch multi-tokenizer router
|
||||
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
|
||||
|
||||
# Initialize templates
|
||||
template_manager = None
|
||||
else:
|
||||
# Launch tokenizer process
|
||||
tokenizer_manager = TokenizerManager(server_args, port_args)
|
||||
|
||||
# Initialize templates
|
||||
template_manager = TemplateManager()
|
||||
template_manager.initialize_templates(
|
||||
tokenizer_manager=tokenizer_manager,
|
||||
model_path=server_args.model_path,
|
||||
chat_template=server_args.chat_template,
|
||||
completion_template=server_args.completion_template,
|
||||
tokenizer_manager, template_manager = _init_tokenizer_manager(
|
||||
server_args, port_args
|
||||
)
|
||||
|
||||
# Wait for the model to finish loading
|
||||
@@ -856,5 +866,7 @@ def _launch_subprocesses(
|
||||
|
||||
# Assume all schedulers have the same scheduler_info
|
||||
scheduler_info = scheduler_infos[0]
|
||||
|
||||
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
||||
|
||||
return tokenizer_manager, template_manager, scheduler_info
|
||||
|
||||
@@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import (
|
||||
MultiTokenizerManager,
|
||||
deserialize_data,
|
||||
get_main_process_id,
|
||||
read_from_shared_memory,
|
||||
write_data_for_multi_tokenizer,
|
||||
@@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState):
|
||||
_global_state = global_state
|
||||
|
||||
|
||||
# Function to set up all middlewares for multi-tokenizer compatibility
|
||||
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
    """Attach the optional API-key and Prometheus middlewares to ``app``.

    Args:
        api_key: When truthy, the API-key middleware is added with this key.
        enable_metrics: When truthy, the Prometheus middleware is added and
            the function timer is enabled.
    """
    worker_pid = os.getpid()

    if api_key:
        add_api_key_middleware(app, api_key)
        logger.info(f"Worker {worker_pid} added API key middleware")

    # Nothing more to install unless metrics were requested.
    if not enable_metrics:
        return

    add_prometheus_middleware(app)
    enable_func_timer()
    logger.info(f"Worker {worker_pid} added prometheus middleware")
|
||||
|
||||
|
||||
async def init_multi_tokenizer() -> ServerArgs:
|
||||
"""Read args information from shm and init tokenizer manager for current process"""
|
||||
pid = os.getpid()
|
||||
@@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs:
|
||||
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
|
||||
|
||||
# Read configuration from shared memory
|
||||
port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
|
||||
server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
|
||||
scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
|
||||
port_args, server_args = deserialize_data(port_args_data, server_args_data)
|
||||
scheduler_info = scheduler_info_data
|
||||
port_args, server_args, scheduler_info = read_from_shared_memory(
|
||||
f"multi_tokenizer_args_{main_pid}"
|
||||
)
|
||||
server_args: ServerArgs
|
||||
|
||||
# API key authentication is not supported in multi-tokenizer mode
|
||||
assert (
|
||||
server_args.api_key is None
|
||||
), "API key is not supported in multi-tokenizer mode"
|
||||
|
||||
port_args.tokenizer_ipc_name = (
|
||||
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
|
||||
@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(fast_api_app: FastAPI):
|
||||
server_args = getattr(fast_api_app, "server_args", None)
|
||||
if server_args is None:
|
||||
if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
|
||||
# Initialize multi-tokenizer support for worker processes
|
||||
fast_api_app.server_args = await init_multi_tokenizer()
|
||||
setup_middlewares(
|
||||
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
|
||||
)
|
||||
fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
|
||||
|
||||
# only metrics middleware is supported in multi-tokenizer mode
|
||||
worker_pid = os.getpid()
|
||||
if fast_api_app.server_args.enable_metrics:
|
||||
add_prometheus_middleware(app)
|
||||
enable_func_timer()
|
||||
|
||||
logger.info(f"Worker {worker_pid} added prometheus middleware")
|
||||
fast_api_app.warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
@@ -1187,12 +1179,10 @@ def launch_server(
|
||||
)
|
||||
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
port_args_shm, server_args_shm, scheduler_info_shm = (
|
||||
write_data_for_multi_tokenizer(
|
||||
port_args,
|
||||
server_args,
|
||||
scheduler_info,
|
||||
)
|
||||
multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
|
||||
port_args,
|
||||
server_args,
|
||||
scheduler_info,
|
||||
)
|
||||
else:
|
||||
# Add api key authorization
|
||||
@@ -1239,6 +1229,7 @@ def launch_server(
|
||||
workers=server_args.tokenizer_worker_num,
|
||||
)
|
||||
else:
|
||||
app.is_single_tokenizer_mode = True
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
@@ -1249,10 +1240,8 @@ def launch_server(
|
||||
)
|
||||
finally:
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
port_args_shm.unlink()
|
||||
server_args_shm.unlink()
|
||||
scheduler_info_shm.unlink()
|
||||
_global_state.tokenizer_manager.clear_tokenizer_mapping()
|
||||
multi_tokenizer_args_shm.unlink()
|
||||
_global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
|
||||
else:
|
||||
warmup_thread.join()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user