Revert "Rename TokenizerManager to StdOrchestrator" (#3828)
This commit is contained in:
@@ -241,7 +241,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
|
||||
return pixel_values, image_hash, image.size
|
||||
except Exception:
|
||||
logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback())
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
async def _process_single_image(
|
||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||
@@ -491,7 +491,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
|
||||
return pixel_values, image_hash, image.size, image_grid_thws
|
||||
except Exception:
|
||||
logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback())
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
async def _process_single_image(self, image_data: Union[bytes, str]):
|
||||
if self.executor is not None:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# ==============================================================================
|
||||
"""
|
||||
The definition of objects transfered between different
|
||||
processes (StdOrchestrator, DetokenizerManager, Controller).
|
||||
processes (TokenizerManager, DetokenizerManager, Controller).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
@@ -174,7 +174,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
if server_args.skip_tokenizer_init:
|
||||
# Directly send to the StdOrchestrator
|
||||
# Directly send to the TokenizerManager
|
||||
self.send_to_detokenizer = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
||||
)
|
||||
|
||||
480
python/sglang/srt/managers/tokenizer_manager.py
Normal file
480
python/sglang/srt/managers/tokenizer_manager.py
Normal file
@@ -0,0 +1,480 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TokenizerManager is a process that tokenizes the text."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Awaitable, Generic, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import fastapi
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
from sglang.srt.aio_rwlock import RWLock
|
||||
from sglang.srt.managers.generation_manager import GenerationManager
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
ConfigureLoggingReq,
|
||||
EmbeddingReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ReleaseMemoryOccupationReqOutput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenizerManager:
|
||||
"""TokenizerManager is a process that tokenizes the text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
):
|
||||
# Parse args
|
||||
|
||||
self.server_args = server_args
|
||||
self.enable_metrics = server_args.enable_metrics
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.asyncio.Context(2)
|
||||
self.recv_from_detokenizer = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
||||
)
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||
)
|
||||
|
||||
# Read model args
|
||||
self.model_path = server_args.model_path
|
||||
self.served_model_name = server_args.served_model_name
|
||||
|
||||
# Store states
|
||||
self.no_create_loop = False
|
||||
|
||||
# The event to notify the weight sync is finished.
|
||||
self.model_update_lock = RWLock()
|
||||
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
||||
None
|
||||
)
|
||||
self.asyncio_tasks = set()
|
||||
|
||||
# For session info
|
||||
self.session_futures = {} # session_id -> asyncio event
|
||||
|
||||
self._generation_manager = GenerationManager(
|
||||
server_args=server_args,
|
||||
on_request=self.send_to_scheduler.send_pyobj,
|
||||
)
|
||||
|
||||
# Others
|
||||
self.gracefully_exit = False
|
||||
self.init_weights_update_group_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.update_weights_from_distributed_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.update_weights_from_tensor_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.get_weights_by_name_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.release_memory_occupation_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.resume_memory_occupation_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
|
||||
self._result_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(
|
||||
(BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
|
||||
self._generation_manager.handle_batch_output,
|
||||
),
|
||||
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
||||
(
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
self._handle_update_weights_from_disk_req_output,
|
||||
),
|
||||
(
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
self.init_weights_update_group_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
self.update_weights_from_distributed_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
self.update_weights_from_tensor_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
GetWeightsByNameReqOutput,
|
||||
self.get_weights_by_name_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
ReleaseMemoryOccupationReqOutput,
|
||||
self.release_memory_occupation_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
self.resume_memory_occupation_communicator.handle_recv,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
self.auto_create_handle_loop()
|
||||
async with self.model_update_lock.reader_lock:
|
||||
async for value in self._generation_manager.generate_request(obj, request):
|
||||
yield value
|
||||
|
||||
def flush_cache(self):
|
||||
req = FlushCacheReq()
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def abort_request(self, rid: str):
|
||||
self._generation_manager.abort_request(rid)
|
||||
|
||||
def start_profile(self):
|
||||
req = ProfileReq.START_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def stop_profile(self):
|
||||
req = ProfileReq.STOP_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
async def update_weights_from_disk(
|
||||
self,
|
||||
obj: UpdateWeightFromDiskReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
self.auto_create_handle_loop()
|
||||
|
||||
# default the load format to the server_args
|
||||
if obj.load_format is None:
|
||||
obj.load_format = self.server_args.load_format
|
||||
logger.info("Start update_weights. Load format=%s", obj.load_format)
|
||||
|
||||
if True:
|
||||
# Hold the lock if it is not async. This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
return await self._wait_for_model_update_from_disk(obj)
|
||||
|
||||
async def _wait_for_model_update_from_disk(
|
||||
self, obj: UpdateWeightFromDiskReqInput
|
||||
) -> Tuple[bool, str]:
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
if self.server_args.dp_size == 1:
|
||||
result = await self.model_update_result
|
||||
if result.success:
|
||||
self.served_model_name = obj.model_path
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
return result.success, result.message
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp = []
|
||||
result = await self.model_update_result
|
||||
|
||||
all_success = all([r.success for r in result])
|
||||
if all_success is True:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
all_message = [r.message for r in result]
|
||||
all_message = " | ".join(all_message)
|
||||
return all_success, all_message
|
||||
|
||||
async def init_weights_update_group(
|
||||
self,
|
||||
obj: InitWeightsUpdateGroupReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
self.auto_create_handle_loop()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for init parameter update group"
|
||||
result = (await self.init_weights_update_group_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def update_weights_from_distributed(
|
||||
self,
|
||||
obj: UpdateWeightsFromDistributedReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
self.auto_create_handle_loop()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be for update weights from distributed"
|
||||
|
||||
# This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def update_weights_from_tensor(
|
||||
self,
|
||||
obj: UpdateWeightsFromTensorReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
self.auto_create_handle_loop()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be for update weights from distributed"
|
||||
|
||||
# This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
self.auto_create_handle_loop()
|
||||
results = await self.get_weights_by_name_communicator(obj)
|
||||
all_parameters = [r.parameter for r in results]
|
||||
if self.server_args.dp_size == 1:
|
||||
return all_parameters[0]
|
||||
else:
|
||||
return all_parameters
|
||||
|
||||
async def release_memory_occupation(
|
||||
self,
|
||||
obj: ReleaseMemoryOccupationReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
self.auto_create_handle_loop()
|
||||
await self.release_memory_occupation_communicator(obj)
|
||||
|
||||
async def resume_memory_occupation(
|
||||
self,
|
||||
obj: ResumeMemoryOccupationReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
self.auto_create_handle_loop()
|
||||
await self.resume_memory_occupation_communicator(obj)
|
||||
|
||||
async def open_session(
|
||||
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
self.auto_create_handle_loop()
|
||||
|
||||
if obj.session_id is None:
|
||||
obj.session_id = uuid.uuid4().hex
|
||||
elif obj.session_id in self.session_futures:
|
||||
return None
|
||||
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
self.session_futures[obj.session_id] = asyncio.Future()
|
||||
session_id = await self.session_futures[obj.session_id]
|
||||
del self.session_futures[obj.session_id]
|
||||
return session_id
|
||||
|
||||
async def close_session(
|
||||
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
await self.send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
def configure_logging(self, obj: ConfigureLoggingReq):
|
||||
self._generation_manager.configure_logging(obj)
|
||||
|
||||
def create_abort_task(self, obj: GenerateReqInput):
|
||||
# Abort the request if the client is disconnected.
|
||||
async def abort_request():
|
||||
await asyncio.sleep(1)
|
||||
if obj.is_single:
|
||||
self.abort_request(obj.rid)
|
||||
else:
|
||||
for rid in obj.rid:
|
||||
self.abort_request(rid)
|
||||
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(abort_request)
|
||||
return background_tasks
|
||||
|
||||
def auto_create_handle_loop(self):
|
||||
if self.no_create_loop:
|
||||
return
|
||||
|
||||
self.no_create_loop = True
|
||||
loop = asyncio.get_event_loop()
|
||||
self.asyncio_tasks.add(
|
||||
loop.create_task(print_exception_wrapper(self.handle_loop))
|
||||
)
|
||||
|
||||
# We cannot add signal handler when the tokenizer manager is not in
|
||||
# the main thread due to the CPython limitation.
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
signal_handler = SignalHandler(self)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
|
||||
else:
|
||||
logger.warning(
|
||||
"Signal handler is not added because the tokenizer manager is "
|
||||
"not in the main thread. This disables graceful shutdown of the "
|
||||
"tokenizer manager when SIGTERM is received."
|
||||
)
|
||||
self.asyncio_tasks.add(
|
||||
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
|
||||
)
|
||||
|
||||
async def sigterm_watchdog(self):
|
||||
while not self.gracefully_exit:
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Drain requests
|
||||
while True:
|
||||
remain_num_req = len(self._generation_manager.rid_to_state)
|
||||
logger.info(
|
||||
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
||||
)
|
||||
if remain_num_req > 0:
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
break
|
||||
|
||||
kill_process_tree(os.getpid(), include_parent=True)
|
||||
sys.exit(0)
|
||||
|
||||
async def handle_loop(self):
|
||||
"""The event loop that handles requests"""
|
||||
|
||||
while True:
|
||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||
self._result_dispatcher(recv_obj)
|
||||
|
||||
def _handle_open_session_req_output(self, recv_obj):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id if recv_obj.success else None
|
||||
)
|
||||
|
||||
def _handle_update_weights_from_disk_req_output(self, recv_obj):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp.append(recv_obj)
|
||||
# set future if the all results are recevied
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
|
||||
@property
|
||||
def is_generation(self):
|
||||
return self._generation_manager.model_config.is_generation
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
return self._generation_manager.tokenizer
|
||||
|
||||
@property
|
||||
def image_token_id(self):
|
||||
return self._generation_manager.model_config.image_token_id
|
||||
|
||||
def configure_max_req_input_len(self, max_req_input_len):
|
||||
self._generation_manager.generation_converter.max_req_input_len = (
|
||||
max_req_input_len
|
||||
)
|
||||
|
||||
|
||||
async def print_exception_wrapper(func):
|
||||
"""
|
||||
Sometimes an asyncio function does not print exception.
|
||||
We do another wrapper to handle the exception.
|
||||
"""
|
||||
try:
|
||||
await func()
|
||||
except Exception:
|
||||
traceback = get_exception_traceback()
|
||||
logger.error(f"TokenizerManager hit an exception: {traceback}")
|
||||
kill_process_tree(os.getpid(), include_parent=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class SignalHandler:
|
||||
def __init__(self, tokenizer_manager):
|
||||
self.tokenizer_manager = tokenizer_manager
|
||||
|
||||
def signal_handler(self, signum=None, frame=None):
|
||||
logger.warning(
|
||||
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
||||
)
|
||||
self.tokenizer_manager.gracefully_exit = True
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class _Communicator(Generic[T]):
|
||||
def __init__(self, sender, fan_out: int):
|
||||
self._sender = sender
|
||||
self._fan_out = fan_out
|
||||
self._result_future: Optional[asyncio.Future] = None
|
||||
self._result_values: Optional[List[T]] = None
|
||||
|
||||
async def __call__(self, obj):
|
||||
self._sender.send_pyobj(obj)
|
||||
self._result_future = asyncio.Future()
|
||||
self._result_values = []
|
||||
await self._result_future
|
||||
result_values = self._result_values
|
||||
self._result_future = self._result_values = None
|
||||
return result_values
|
||||
|
||||
def handle_recv(self, recv_obj: T):
|
||||
self._result_values.append(recv_obj)
|
||||
if len(self._result_values) == self._fan_out:
|
||||
self._result_future.set_result(None)
|
||||
Reference in New Issue
Block a user