Signed-off-by: ybyang <ybyang7@iflytek.com> Signed-off-by: huanglong <huanglong@linux.alibaba.com> Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com> Co-authored-by: huanglong <huanglong@linux.alibaba.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.managers.template_manager import TemplateManager
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -814,18 +815,24 @@ def _launch_subprocesses(
|
||||
),
|
||||
)
|
||||
detoken_proc.start()
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
# Launch multi-tokenizer router
|
||||
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
|
||||
|
||||
# Launch tokenizer process
|
||||
tokenizer_manager = TokenizerManager(server_args, port_args)
|
||||
# Initialize templates
|
||||
template_manager = None
|
||||
else:
|
||||
# Launch tokenizer process
|
||||
tokenizer_manager = TokenizerManager(server_args, port_args)
|
||||
|
||||
# Initialize templates
|
||||
template_manager = TemplateManager()
|
||||
template_manager.initialize_templates(
|
||||
tokenizer_manager=tokenizer_manager,
|
||||
model_path=server_args.model_path,
|
||||
chat_template=server_args.chat_template,
|
||||
completion_template=server_args.completion_template,
|
||||
)
|
||||
# Initialize templates
|
||||
template_manager = TemplateManager()
|
||||
template_manager.initialize_templates(
|
||||
tokenizer_manager=tokenizer_manager,
|
||||
model_path=server_args.model_path,
|
||||
chat_template=server_args.chat_template,
|
||||
completion_template=server_args.completion_template,
|
||||
)
|
||||
|
||||
# Wait for the model to finish loading
|
||||
scheduler_infos = []
|
||||
|
||||
@@ -23,6 +23,7 @@ import json
|
||||
import logging
|
||||
import multiprocessing as multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
@@ -91,11 +92,18 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightVersionReqInput,
|
||||
VertexGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import (
|
||||
MultiTokenizerManager,
|
||||
deserialize_data,
|
||||
get_main_process_id,
|
||||
read_from_shared_memory,
|
||||
write_data_for_multi_tokenizer,
|
||||
)
|
||||
from sglang.srt.managers.template_manager import TemplateManager
|
||||
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
|
||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
add_api_key_middleware,
|
||||
add_prometheus_middleware,
|
||||
@@ -130,8 +138,79 @@ def set_global_state(global_state: _GlobalState):
|
||||
_global_state = global_state
|
||||
|
||||
|
||||
# Function to set up all middlewares for multi-tokenizer compatibility
|
||||
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
|
||||
"""Setup all middlewares for both single and multi-process modes"""
|
||||
worker_pid = os.getpid()
|
||||
|
||||
if api_key:
|
||||
add_api_key_middleware(app, api_key)
|
||||
logger.info(f"Worker {worker_pid} added API key middleware")
|
||||
|
||||
if enable_metrics:
|
||||
add_prometheus_middleware(app)
|
||||
enable_func_timer()
|
||||
logger.info(f"Worker {worker_pid} added prometheus middleware")
|
||||
|
||||
|
||||
async def init_multi_tokenizer() -> ServerArgs:
|
||||
"""Read args information from shm and init tokenizer manager for current process"""
|
||||
pid = os.getpid()
|
||||
main_pid = get_main_process_id()
|
||||
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
|
||||
|
||||
# Read configuration from shared memory
|
||||
port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
|
||||
server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
|
||||
scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
|
||||
port_args, server_args = deserialize_data(port_args_data, server_args_data)
|
||||
scheduler_info = scheduler_info_data
|
||||
|
||||
port_args.tokenizer_ipc_name = (
|
||||
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
|
||||
)
|
||||
|
||||
# Launch multi-tokenizer manager process
|
||||
tokenizer_manager = MultiTokenizerManager(server_args, port_args)
|
||||
template_manager = TemplateManager()
|
||||
template_manager.initialize_templates(
|
||||
tokenizer_manager=tokenizer_manager,
|
||||
model_path=server_args.model_path,
|
||||
chat_template=server_args.chat_template,
|
||||
completion_template=server_args.completion_template,
|
||||
)
|
||||
# Register this tokenizer with the main tokenizer manager
|
||||
await tokenizer_manager.register_to_main_tokenizer_manager()
|
||||
|
||||
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
||||
set_global_state(
|
||||
_GlobalState(
|
||||
tokenizer_manager=tokenizer_manager,
|
||||
template_manager=template_manager,
|
||||
scheduler_info=scheduler_info,
|
||||
)
|
||||
)
|
||||
return server_args
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(fast_api_app: FastAPI):
|
||||
server_args = getattr(fast_api_app, "server_args", None)
|
||||
if server_args is None:
|
||||
# Initialize multi-tokenizer support for worker processes
|
||||
fast_api_app.server_args = await init_multi_tokenizer()
|
||||
setup_middlewares(
|
||||
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
|
||||
)
|
||||
fast_api_app.warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
fast_api_app.server_args,
|
||||
None, # pipe_finish_writer not needed in worker
|
||||
None, # launch_callback not needed in worker
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize OpenAI serving handlers
|
||||
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
|
||||
_global_state.tokenizer_manager, _global_state.template_manager
|
||||
@@ -191,7 +270,15 @@ async def lifespan(fast_api_app: FastAPI):
|
||||
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
|
||||
if warmup_thread is not None:
|
||||
warmup_thread.start()
|
||||
yield
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
pid = os.getpid()
|
||||
logger.info(f"uvicorn worker {pid} ending...")
|
||||
warmup_thread.join()
|
||||
logger.info(f"uvicorn worker {pid} ended.")
|
||||
|
||||
|
||||
# Fast API
|
||||
@@ -1078,9 +1165,19 @@ def launch_server(
|
||||
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
|
||||
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
|
||||
"""
|
||||
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
|
||||
server_args=server_args
|
||||
)
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
port_args = PortArgs.init_new(server_args)
|
||||
port_args.tokenizer_worker_ipc_name = (
|
||||
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
|
||||
)
|
||||
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
|
||||
server_args=server_args, port_args=port_args
|
||||
)
|
||||
else:
|
||||
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
|
||||
server_args=server_args,
|
||||
)
|
||||
|
||||
set_global_state(
|
||||
_GlobalState(
|
||||
tokenizer_manager=tokenizer_manager,
|
||||
@@ -1089,42 +1186,75 @@ def launch_server(
|
||||
)
|
||||
)
|
||||
|
||||
# Add api key authorization
|
||||
if server_args.api_key:
|
||||
add_api_key_middleware(app, server_args.api_key)
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
port_args_shm, server_args_shm, scheduler_info_shm = (
|
||||
write_data_for_multi_tokenizer(
|
||||
port_args,
|
||||
server_args,
|
||||
scheduler_info,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Add api key authorization
|
||||
if server_args.api_key:
|
||||
add_api_key_middleware(app, server_args.api_key)
|
||||
|
||||
# Add prometheus middleware
|
||||
if server_args.enable_metrics:
|
||||
add_prometheus_middleware(app)
|
||||
enable_func_timer()
|
||||
# Add prometheus middleware
|
||||
if server_args.enable_metrics:
|
||||
add_prometheus_middleware(app)
|
||||
enable_func_timer()
|
||||
|
||||
# Send a warmup request - we will create the thread launch it
|
||||
# in the lifespan after all other warmups have fired.
|
||||
warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
server_args,
|
||||
pipe_finish_writer,
|
||||
launch_callback,
|
||||
),
|
||||
)
|
||||
app.warmup_thread = warmup_thread
|
||||
# Send a warmup request - we will create the thread launch it
|
||||
# in the lifespan after all other warmups have fired.
|
||||
warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
server_args,
|
||||
pipe_finish_writer,
|
||||
launch_callback,
|
||||
),
|
||||
)
|
||||
app.warmup_thread = warmup_thread
|
||||
|
||||
try:
|
||||
# Update logging configs
|
||||
set_uvicorn_logging_configs()
|
||||
app.server_args = server_args
|
||||
# Listen for HTTP requests
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
port=server_args.port,
|
||||
log_level=server_args.log_level_http or server_args.log_level,
|
||||
timeout_keep_alive=5,
|
||||
loop="uvloop",
|
||||
)
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
from uvicorn.config import LOGGING_CONFIG
|
||||
|
||||
LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
}
|
||||
uvicorn.run(
|
||||
"sglang.srt.entrypoints.http_server:app",
|
||||
host=server_args.host,
|
||||
port=server_args.port,
|
||||
log_level=server_args.log_level_http or server_args.log_level,
|
||||
timeout_keep_alive=5,
|
||||
loop="uvloop",
|
||||
workers=server_args.tokenizer_worker_num,
|
||||
)
|
||||
else:
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
port=server_args.port,
|
||||
log_level=server_args.log_level_http or server_args.log_level,
|
||||
timeout_keep_alive=5,
|
||||
loop="uvloop",
|
||||
)
|
||||
finally:
|
||||
warmup_thread.join()
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
port_args_shm.unlink()
|
||||
server_args_shm.unlink()
|
||||
scheduler_info_shm.unlink()
|
||||
_global_state.tokenizer_manager.clear_tokenizer_mapping()
|
||||
else:
|
||||
warmup_thread.join()
|
||||
|
||||
|
||||
def _execute_server_warmup(
|
||||
|
||||
@@ -32,11 +32,14 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
FreezeGCReq,
|
||||
MultiTokenizerRegisterReq,
|
||||
)
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
freeze_gc,
|
||||
get_worker_ids_from_req_rids,
|
||||
get_zmq_socket,
|
||||
kill_itself_when_parent_died,
|
||||
)
|
||||
@@ -67,7 +70,7 @@ class DecodeStatus:
|
||||
sent_offset: int = 0
|
||||
|
||||
|
||||
class DetokenizerManager:
|
||||
class DetokenizerManager(MultiTokenizerMixin):
|
||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||
|
||||
def __init__(
|
||||
@@ -102,6 +105,7 @@ class DetokenizerManager:
|
||||
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
||||
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
||||
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||
(MultiTokenizerRegisterReq, lambda x: x),
|
||||
(FreezeGCReq, self.handle_freeze_gc_req),
|
||||
]
|
||||
)
|
||||
@@ -116,6 +120,39 @@ class DetokenizerManager:
|
||||
if output is not None:
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
|
||||
def multi_tokenizer_manager_event_loop(self):
|
||||
"""The event loop that handles requests, for multi tokenizer manager mode only"""
|
||||
self.create_sockets_mapping()
|
||||
while True:
|
||||
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
||||
output = self._request_dispatcher(recv_obj)
|
||||
if output is None:
|
||||
continue
|
||||
# Extract worker_id from rid
|
||||
if isinstance(recv_obj.rids, list):
|
||||
worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
|
||||
)
|
||||
|
||||
# Send data using the corresponding socket
|
||||
for i, worker_id in enumerate(worker_ids):
|
||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||
if self.register_tokenizer_ipc(recv_obj, worker_id):
|
||||
logger.info(
|
||||
f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
if worker_id not in self.tokenizer_mapping:
|
||||
logger.error(
|
||||
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
|
||||
)
|
||||
continue
|
||||
new_output = self._handle_output_by_index(output, i)
|
||||
self.tokenizer_mapping[worker_id].send_pyobj(new_output)
|
||||
|
||||
def trim_matched_stop(
|
||||
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
||||
):
|
||||
@@ -285,8 +322,12 @@ def run_detokenizer_process(
|
||||
|
||||
try:
|
||||
manager = DetokenizerManager(server_args, port_args)
|
||||
manager.event_loop()
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
manager.multi_tokenizer_manager_event_loop()
|
||||
else:
|
||||
manager.event_loop()
|
||||
except Exception:
|
||||
manager.clear_tokenizer_mapping()
|
||||
traceback = get_exception_traceback()
|
||||
logger.error(f"DetokenizerManager hit an exception: {traceback}")
|
||||
parent_process.send_signal(signal.SIGQUIT)
|
||||
|
||||
@@ -983,6 +983,11 @@ class AbortReq:
|
||||
abort_all: bool = False
|
||||
# The finished reason data
|
||||
finished_reason: Optional[Dict[str, Any]] = None
|
||||
# used in MultiTokenzierManager mode
|
||||
rids: Optional[Union[List[str], str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.rids = self.rid
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1183,6 +1188,18 @@ class LoRAUpdateResult:
|
||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiTokenizerRegisterReq:
|
||||
rids: Optional[Union[List[str], str]] = None
|
||||
ipc_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiTokenizerWarpper:
|
||||
worker_id: int
|
||||
obj: Optional[Any] = None
|
||||
|
||||
|
||||
class BlockReqType(Enum):
|
||||
BLOCK = 1
|
||||
UNBLOCK = 2
|
||||
|
||||
591
python/sglang/srt/managers/multi_tokenizer_mixin.py
Normal file
591
python/sglang/srt/managers/multi_tokenizer_mixin.py
Normal file
@@ -0,0 +1,591 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Dict
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchMultimodalOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
MultiTokenizerRegisterReq,
|
||||
MultiTokenizerWarpper,
|
||||
)
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_worker_ids_from_req_rids,
|
||||
get_zmq_socket,
|
||||
kill_process_tree,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiTokenizerMixin:
|
||||
"""Mixin class for MultiTokenizerManager and DetokenizerManager"""
|
||||
|
||||
def create_sockets_mapping(self):
|
||||
if not hasattr(self, "tokenizer_mapping"):
|
||||
self.tokenizer_mapping = {}
|
||||
# Create ZMQ context if needed
|
||||
if not hasattr(self, "_zmq_context"):
|
||||
self._zmq_context = zmq.Context()
|
||||
|
||||
def init_tokenizer_mapping(
|
||||
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
|
||||
):
|
||||
"""init tokenizer mapping from register request"""
|
||||
ipc_name = recv_obj.ipc_name
|
||||
worker_id_int = int(worker_id)
|
||||
|
||||
if worker_id_int not in self.tokenizer_mapping:
|
||||
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
|
||||
self.tokenizer_mapping[worker_id_int] = socket
|
||||
self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def register_tokenizer_ipc(self, recv_obj, worker_id):
|
||||
if worker_id not in self.tokenizer_mapping:
|
||||
# register the worker if not already done
|
||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||
return self.init_tokenizer_mapping(recv_obj, worker_id)
|
||||
else:
|
||||
logger.error(
|
||||
f"Worker {worker_id} not registered and not found in tokenizer mapping . "
|
||||
"Please ensure the worker is registered correctly."
|
||||
)
|
||||
return False
|
||||
|
||||
def _handle_output_by_index(self, output, i):
|
||||
"""NOTE: A maintainable method is better here."""
|
||||
if isinstance(output, BatchTokenIDOut):
|
||||
new_output = BatchTokenIDOut(
|
||||
rids=[output.rids[i]],
|
||||
finished_reasons=(
|
||||
[output.finished_reasons[i]]
|
||||
if len(output.finished_reasons) > i
|
||||
else None
|
||||
),
|
||||
decoded_texts=(
|
||||
[output.decoded_texts[i]] if len(output.decoded_texts) > i else None
|
||||
),
|
||||
decode_ids=(
|
||||
[output.decode_ids[i]] if len(output.decode_ids) > i else None
|
||||
),
|
||||
read_offsets=(
|
||||
[output.read_offsets[i]] if len(output.read_offsets) > i else None
|
||||
),
|
||||
output_ids=(
|
||||
[output.output_ids[i]]
|
||||
if output.output_ids and len(output.output_ids) > i
|
||||
else None
|
||||
),
|
||||
skip_special_tokens=(
|
||||
[output.skip_special_tokens[i]]
|
||||
if len(output.skip_special_tokens) > i
|
||||
else None
|
||||
),
|
||||
spaces_between_special_tokens=(
|
||||
[output.spaces_between_special_tokens[i]]
|
||||
if len(output.spaces_between_special_tokens) > i
|
||||
else None
|
||||
),
|
||||
no_stop_trim=(
|
||||
[output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
|
||||
),
|
||||
prompt_tokens=(
|
||||
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
||||
),
|
||||
completion_tokens=(
|
||||
[output.completion_tokens[i]]
|
||||
if len(output.completion_tokens) > i
|
||||
else None
|
||||
),
|
||||
cached_tokens=(
|
||||
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
||||
),
|
||||
spec_verify_ct=(
|
||||
[output.spec_verify_ct[i]]
|
||||
if len(output.spec_verify_ct) > i
|
||||
else None
|
||||
),
|
||||
input_token_logprobs_val=(
|
||||
[output.input_token_logprobs_val[i]]
|
||||
if output.input_token_logprobs_val
|
||||
else None
|
||||
),
|
||||
input_token_logprobs_idx=(
|
||||
[output.input_token_logprobs_idx[i]]
|
||||
if output.input_token_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_token_logprobs_val=(
|
||||
[output.output_token_logprobs_val[i]]
|
||||
if output.output_token_logprobs_val
|
||||
else None
|
||||
),
|
||||
output_token_logprobs_idx=(
|
||||
[output.output_token_logprobs_idx[i]]
|
||||
if output.output_token_logprobs_idx
|
||||
else None
|
||||
),
|
||||
input_top_logprobs_val=(
|
||||
[output.input_top_logprobs_val[i]]
|
||||
if output.input_top_logprobs_val
|
||||
else None
|
||||
),
|
||||
input_top_logprobs_idx=(
|
||||
[output.input_top_logprobs_idx[i]]
|
||||
if output.input_top_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_top_logprobs_val=(
|
||||
[output.output_top_logprobs_val[i]]
|
||||
if output.output_top_logprobs_val
|
||||
else None
|
||||
),
|
||||
output_top_logprobs_idx=(
|
||||
[output.output_top_logprobs_idx[i]]
|
||||
if output.output_top_logprobs_idx
|
||||
else None
|
||||
),
|
||||
input_token_ids_logprobs_val=(
|
||||
[output.input_token_ids_logprobs_val[i]]
|
||||
if output.input_token_ids_logprobs_val
|
||||
else None
|
||||
),
|
||||
input_token_ids_logprobs_idx=(
|
||||
[output.input_token_ids_logprobs_idx[i]]
|
||||
if output.input_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_token_ids_logprobs_val=(
|
||||
[output.output_token_ids_logprobs_val[i]]
|
||||
if output.output_token_ids_logprobs_val
|
||||
else None
|
||||
),
|
||||
output_token_ids_logprobs_idx=(
|
||||
[output.output_token_ids_logprobs_idx[i]]
|
||||
if output.output_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_hidden_states=(
|
||||
[output.output_hidden_states[i]]
|
||||
if output.output_hidden_states
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif isinstance(output, BatchEmbeddingOut):
|
||||
new_output = BatchEmbeddingOut(
|
||||
rids=[output.rids[i]],
|
||||
finished_reasons=(
|
||||
[output.finished_reasons[i]]
|
||||
if len(output.finished_reasons) > i
|
||||
else None
|
||||
),
|
||||
embeddings=(
|
||||
[output.embeddings[i]] if len(output.embeddings) > i else None
|
||||
),
|
||||
prompt_tokens=(
|
||||
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
||||
),
|
||||
cached_tokens=(
|
||||
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
||||
),
|
||||
)
|
||||
elif isinstance(output, BatchStrOut):
|
||||
new_output = BatchStrOut(
|
||||
rids=[output.rids[i]],
|
||||
finished_reasons=(
|
||||
[output.finished_reasons[i]]
|
||||
if len(output.finished_reasons) > i
|
||||
else None
|
||||
),
|
||||
output_strs=(
|
||||
[output.output_strs[i]] if len(output.output_strs) > i else None
|
||||
),
|
||||
output_ids=(
|
||||
[output.output_ids[i]]
|
||||
if output.output_ids and len(output.output_ids) > i
|
||||
else None
|
||||
),
|
||||
prompt_tokens=(
|
||||
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
||||
),
|
||||
completion_tokens=(
|
||||
[output.completion_tokens[i]]
|
||||
if len(output.completion_tokens) > i
|
||||
else None
|
||||
),
|
||||
cached_tokens=(
|
||||
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
||||
),
|
||||
spec_verify_ct=(
|
||||
[output.spec_verify_ct[i]]
|
||||
if len(output.spec_verify_ct) > i
|
||||
else None
|
||||
),
|
||||
input_token_logprobs_val=(
|
||||
[output.input_token_logprobs_val[i]]
|
||||
if output.input_token_logprobs_val
|
||||
else None
|
||||
),
|
||||
input_token_logprobs_idx=(
|
||||
[output.input_token_logprobs_idx[i]]
|
||||
if output.input_token_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_token_logprobs_val=(
|
||||
[output.output_token_logprobs_val[i]]
|
||||
if output.output_token_logprobs_val
|
||||
else None
|
||||
),
|
||||
output_token_logprobs_idx=(
|
||||
[output.output_token_logprobs_idx[i]]
|
||||
if output.output_token_logprobs_idx
|
||||
else None
|
||||
),
|
||||
input_top_logprobs_val=(
|
||||
[output.input_top_logprobs_val[i]]
|
||||
if output.input_top_logprobs_val
|
||||
else None
|
||||
),
|
||||
input_top_logprobs_idx=(
|
||||
[output.input_top_logprobs_idx[i]]
|
||||
if output.input_top_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_top_logprobs_val=(
|
||||
[output.output_top_logprobs_val[i]]
|
||||
if output.output_top_logprobs_val
|
||||
else None
|
||||
),
|
||||
output_top_logprobs_idx=(
|
||||
[output.output_top_logprobs_idx[i]]
|
||||
if output.output_top_logprobs_idx
|
||||
else None
|
||||
),
|
||||
input_token_ids_logprobs_val=(
|
||||
[output.input_token_ids_logprobs_val[i]]
|
||||
if output.input_token_ids_logprobs_val
|
||||
else None
|
||||
),
|
||||
input_token_ids_logprobs_idx=(
|
||||
[output.input_token_ids_logprobs_idx[i]]
|
||||
if output.input_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_token_ids_logprobs_val=(
|
||||
[output.output_token_ids_logprobs_val[i]]
|
||||
if output.output_token_ids_logprobs_val
|
||||
else None
|
||||
),
|
||||
output_token_ids_logprobs_idx=(
|
||||
[output.output_token_ids_logprobs_idx[i]]
|
||||
if output.output_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_hidden_states=(
|
||||
[output.output_hidden_states[i]]
|
||||
if output.output_hidden_states
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif isinstance(output, BatchMultimodalOut):
|
||||
new_output = BatchMultimodalOut(
|
||||
rids=[output.rids[i]],
|
||||
finished_reasons=(
|
||||
[output.finished_reasons[i]]
|
||||
if len(output.finished_reasons) > i
|
||||
else None
|
||||
),
|
||||
outputs=([output.outputs[i]] if len(output.outputs) > i else None),
|
||||
prompt_tokens=(
|
||||
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
||||
),
|
||||
completion_tokens=(
|
||||
[output.completion_tokens[i]]
|
||||
if len(output.completion_tokens) > i
|
||||
else None
|
||||
),
|
||||
cached_tokens=(
|
||||
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
new_output = output
|
||||
return new_output
|
||||
|
||||
def clear_tokenizer_mapping(self):
|
||||
if hasattr(self, "tokenizer_mapping"):
|
||||
for socket in self.tokenizer_mapping.values():
|
||||
try:
|
||||
socket.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to close socket: {e}")
|
||||
self.tokenizer_mapping.clear()
|
||||
|
||||
|
||||
class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
|
||||
"""A router to receive requests from MultiTokenizerManager"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
):
|
||||
self.server_args = server_args
|
||||
context = zmq.asyncio.Context(3)
|
||||
self.recv_from_detokenizer = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
||||
)
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||
)
|
||||
self.receive_from_worker = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
|
||||
)
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._task = asyncio.run_coroutine_threadsafe(
|
||||
self.router_worker_obj(), self._loop
|
||||
)
|
||||
# Start handle_loop simultaneously
|
||||
self._handle_task = asyncio.run_coroutine_threadsafe(
|
||||
print_exception_wrapper(self.handle_loop), self._loop
|
||||
)
|
||||
self.init_disaggregation()
|
||||
|
||||
def _run_loop(self):
|
||||
self._loop.run_forever()
|
||||
|
||||
async def router_worker_obj(self):
|
||||
while True:
|
||||
recv_obj = await self.receive_from_worker.recv_pyobj()
|
||||
await self.send_to_scheduler.send_pyobj(recv_obj)
|
||||
|
||||
async def handle_loop(self):
|
||||
# special reqs will recv from scheduler, need to route to right worker
|
||||
self.create_sockets_mapping()
|
||||
while True:
|
||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||
await self._distribute_result_to_workers(recv_obj)
|
||||
|
||||
async def _distribute_result_to_workers(self, recv_obj):
|
||||
"""Distribute result to corresponding workers based on rid"""
|
||||
if isinstance(recv_obj, MultiTokenizerWarpper):
|
||||
worker_ids = [recv_obj.worker_id]
|
||||
recv_obj = recv_obj.obj
|
||||
else:
|
||||
worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
|
||||
|
||||
if len(worker_ids) == 0:
|
||||
logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
|
||||
return
|
||||
|
||||
# Distribute result to each worker
|
||||
for i, worker_id in enumerate(worker_ids):
|
||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||
if self.register_tokenizer_ipc(recv_obj, worker_id):
|
||||
logger.info(
|
||||
f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
if worker_id not in self.tokenizer_mapping:
|
||||
logger.error(
|
||||
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
|
||||
)
|
||||
continue
|
||||
new_recv_obj = self._handle_output_by_index(recv_obj, i)
|
||||
self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
|
||||
|
||||
|
||||
class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
|
||||
"""Multi Process Tokenizer Manager that tokenizes the text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
):
|
||||
# prevent init prefill bootstrapserver again
|
||||
disaggregation_mode = server_args.disaggregation_mode
|
||||
server_args.disaggregation_mode = "null"
|
||||
super().__init__(server_args, port_args)
|
||||
|
||||
self.worker_id = os.getpid()
|
||||
self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
||||
|
||||
# For PD disaggregtion
|
||||
self.server_args.disaggregation_mode = disaggregation_mode
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
self.disaggregation_transfer_backend = TransferBackend(
|
||||
self.server_args.disaggregation_transfer_backend
|
||||
)
|
||||
# Communicator
|
||||
self.register_multi_tokenizer_communicator = _Communicator(
|
||||
self.send_to_scheduler, 2
|
||||
)
|
||||
self._result_dispatcher._mapping.append(
|
||||
(
|
||||
MultiTokenizerRegisterReq,
|
||||
self.register_multi_tokenizer_communicator.handle_recv,
|
||||
)
|
||||
)
|
||||
|
||||
async def register_to_main_tokenizer_manager(self):
|
||||
"""Register this worker to the main TokenizerManager"""
|
||||
# create a handle loop to receive messages from the main TokenizerManager
|
||||
self.auto_create_handle_loop()
|
||||
req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
|
||||
req.ipc_name = self.tokenizer_ipc_name
|
||||
_Communicator.enable_multi_tokenizer = True
|
||||
await self.register_multi_tokenizer_communicator(req)
|
||||
|
||||
|
||||
async def print_exception_wrapper(func):
|
||||
"""
|
||||
Sometimes an asyncio function does not print exception.
|
||||
We do another wrapper to handle the exception.
|
||||
"""
|
||||
try:
|
||||
await func()
|
||||
except Exception:
|
||||
traceback = get_exception_traceback()
|
||||
logger.error(f"MultiTokenizerRouter hit an exception: {traceback}")
|
||||
if hasattr(func, "__self__") and isinstance(
|
||||
func.__self__, MultiTokenizerRouter
|
||||
):
|
||||
func.__self__.dump_requests_before_crash()
|
||||
kill_process_tree(os.getpid(), include_parent=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def serialize_port_args(port_args: PortArgs) -> dict:
    """Serialize PortArgs into a shareable (JSON-encodable) dictionary."""
    field_names = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
        "nccl_port",
        "rpc_ipc_name",
        "metrics_ipc_name",
        "tokenizer_worker_ipc_name",
    )
    return {name: getattr(port_args, name) for name in field_names}
|
||||
|
||||
|
||||
def deserialize_data(port_args: dict, server_args: dict):
    """Rebuild (PortArgs, ServerArgs) from their shared-memory dictionaries."""
    restored_port_args = PortArgs(**port_args)
    restored_server_args = ServerArgs(**server_args)
    return restored_port_args, restored_server_args
|
||||
|
||||
|
||||
def serialize_server_args(server_args: ServerArgs) -> dict:
    """Serialize ServerArgs into a shareable dictionary.

    ``dataclasses.asdict`` converts nested dataclasses recursively; the result
    is JSON-encodable as long as every field value is.
    """
    return dataclasses.asdict(server_args)
|
||||
|
||||
|
||||
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
    """Serialize scheduler_info into a shareable dictionary.

    The dict is returned unchanged: it is already plain data, so no conversion
    is needed before JSON-encoding it into shared memory.
    """
    return scheduler_info
|
||||
|
||||
|
||||
def deserialize_scheduler_info(data: dict) -> Dict:
    """Deserialize scheduler_info from a shared dictionary.

    Mirror of ``serialize_scheduler_info``: the payload is plain data, so it
    is returned as-is.
    """
    return data
|
||||
|
||||
|
||||
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
    """JSON-encode ``data`` and write it into the shared-memory segment ``name``.

    Reuses an existing segment when it is large enough; otherwise the old one
    is unlinked and a fresh segment of the required size is created.  The open
    handle is returned (the caller owns closing it).
    """
    payload = json.dumps(data).encode("utf-8")
    length = len(payload)
    try:
        segment = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        # No segment yet: create one exactly big enough.
        segment = shared_memory.SharedMemory(create=True, size=length, name=name)
    else:
        # Existing segment too small: recreate it at the required size.
        if segment.size < length:
            segment.close()
            segment.unlink()
            segment = shared_memory.SharedMemory(create=True, size=length, name=name)
    segment.buf[:length] = payload
    return segment
|
||||
|
||||
|
||||
def read_from_shared_memory(name: str) -> dict:
    """Read a JSON payload from the shared-memory segment ``name``.

    The segment may be larger than the payload: POSIX shared memory is
    page-rounded, and the writer reuses oversized segments, so the buffer can
    carry trailing NUL padding or stale bytes from an earlier, longer payload.
    ``json.loads`` on the whole buffer would raise on that trailing garbage;
    ``raw_decode`` parses the leading JSON document and ignores the rest.

    Raises:
        FileNotFoundError: if no shared-memory segment named ``name`` exists.
        json.JSONDecodeError: if the segment does not start with valid JSON.
    """
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        raise FileNotFoundError(f"Shared memory {name} not found")
    try:
        # errors="replace" keeps the decode alive even if stale trailing bytes
        # are not valid UTF-8; the leading JSON document itself is ASCII-safe
        # (json.dumps default).
        raw = bytes(shm.buf).decode("utf-8", errors="replace")
        data, _ = json.JSONDecoder().raw_decode(raw)
        return data
    finally:
        # Close the handle even when decoding fails (the original leaked it).
        shm.close()
|
||||
|
||||
|
||||
def get_main_process_id() -> int:
    """Return the parent PID recorded by ``multiprocessing`` for this process.

    NOTE(review): relies on the private ``_parent_pid`` attribute; it is None
    when called from the main process itself.
    """
    current = multiprocessing.current_process()
    return current._parent_pid
|
||||
|
||||
|
||||
def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Publish launch args to shared memory for multi-tokenizer workers.

    Writes three segments named ``{kind}_{current_pid}`` (port_args,
    server_args, scheduler_info) and returns the three SharedMemory handles
    (already closed; kept so the caller can unlink them later).
    """
    # get main process ID
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    # Each entry: (segment-name prefix, serialized payload).
    handles = []
    for prefix, payload in (
        ("port_args", serialize_port_args(port_args)),
        ("server_args", serialize_server_args(server_args)),
        ("scheduler_info", serialize_scheduler_info(scheduler_info)),
    ):
        handle = write_to_shared_memory(payload, f"{prefix}_{current_pid}")
        handle.close()
        handles.append(handle)

    return tuple(handles)
|
||||
@@ -84,6 +84,8 @@ from sglang.srt.managers.io_struct import (
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
MultiTokenizerRegisterReq,
|
||||
MultiTokenizerWarpper,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -257,7 +259,6 @@ class Scheduler(
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
self.idle_sleeper = None
|
||||
|
||||
if self.pp_rank == 0 and self.attn_tp_rank == 0:
|
||||
self.recv_from_tokenizer = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
||||
@@ -540,6 +541,7 @@ class Scheduler(
|
||||
(ExpertDistributionReq, self.expert_distribution_handle),
|
||||
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
||||
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
||||
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1101,6 +1103,17 @@ class Scheduler(
|
||||
)
|
||||
self.send_to_tokenizer.send_pyobj(abort_req)
|
||||
continue
|
||||
|
||||
# If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
|
||||
if isinstance(recv_req, MultiTokenizerWarpper):
|
||||
worker_id = recv_req.worker_id
|
||||
recv_req = recv_req.obj
|
||||
output = self._request_dispatcher(recv_req)
|
||||
if output is not None:
|
||||
output = MultiTokenizerWarpper(worker_id, output)
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
continue
|
||||
|
||||
output = self._request_dispatcher(recv_req)
|
||||
if output is not None:
|
||||
if isinstance(output, RpcReqOutput):
|
||||
@@ -2474,6 +2487,10 @@ class Scheduler(
|
||||
result = self.tp_worker.unload_lora_adapter(recv_req)
|
||||
return result
|
||||
|
||||
def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
|
||||
self.send_to_detokenizer.send_pyobj(recv_req)
|
||||
return recv_req
|
||||
|
||||
def slow_down(self, recv_req: SlowDownReqInput):
|
||||
t = recv_req.forward_sleep_time
|
||||
if t is not None and t <= 0:
|
||||
|
||||
@@ -94,6 +94,7 @@ from sglang.srt.managers.io_struct import (
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
LoRAUpdateResult,
|
||||
MultiTokenizerWarpper,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -131,6 +132,7 @@ from sglang.srt.utils import (
|
||||
dataclass_to_string_truncated,
|
||||
freeze_gc,
|
||||
get_bool_env_var,
|
||||
get_origin_rid,
|
||||
get_zmq_socket,
|
||||
kill_process_tree,
|
||||
)
|
||||
@@ -266,9 +268,15 @@ class TokenizerManager:
|
||||
self.recv_from_detokenizer = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
||||
)
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||
)
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
# Use tokenizer_worker_ipc_name in multi-tokenizer mode
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
|
||||
)
|
||||
else:
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||
)
|
||||
|
||||
# Request states
|
||||
self.no_create_loop = False
|
||||
@@ -312,35 +320,7 @@ class TokenizerManager:
|
||||
self.lora_update_lock = asyncio.Lock()
|
||||
|
||||
# For PD disaggregtion
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
self.disaggregation_transfer_backend = TransferBackend(
|
||||
self.server_args.disaggregation_transfer_backend
|
||||
)
|
||||
# Start kv boostrap server on prefill
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# only start bootstrap server on prefill tm
|
||||
kv_bootstrap_server_class = get_kv_class(
|
||||
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
)
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
is_create_store = (
|
||||
self.server_args.node_rank == 0
|
||||
and self.server_args.disaggregation_transfer_backend == "ascend"
|
||||
)
|
||||
if is_create_store:
|
||||
try:
|
||||
from mf_adapter import create_config_store
|
||||
|
||||
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
||||
create_config_store(ascend_url)
|
||||
except Exception as e:
|
||||
error_message = f"Failed create mf store, invalid ascend_url."
|
||||
error_message += f" With exception {e}"
|
||||
raise error_message
|
||||
self.init_disaggregation()
|
||||
|
||||
# For load balancing
|
||||
self.current_load = 0
|
||||
@@ -488,6 +468,37 @@ class TokenizerManager:
|
||||
]
|
||||
)
|
||||
|
||||
def init_disaggregation(self):
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
self.disaggregation_transfer_backend = TransferBackend(
|
||||
self.server_args.disaggregation_transfer_backend
|
||||
)
|
||||
# Start kv boostrap server on prefill
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# only start bootstrap server on prefill tm
|
||||
kv_bootstrap_server_class = get_kv_class(
|
||||
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
)
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
is_create_store = (
|
||||
self.server_args.node_rank == 0
|
||||
and self.server_args.disaggregation_transfer_backend == "ascend"
|
||||
)
|
||||
if is_create_store:
|
||||
try:
|
||||
from mf_adapter import create_config_store
|
||||
|
||||
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
||||
create_config_store(ascend_url)
|
||||
except Exception as e:
|
||||
error_message = f"Failed create mf store, invalid ascend_url."
|
||||
error_message += f" With exception {e}"
|
||||
raise error_message
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -497,6 +508,15 @@ class TokenizerManager:
|
||||
self.auto_create_handle_loop()
|
||||
obj.normalize_batch_and_arguments()
|
||||
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
# Modify rid, add worker_id
|
||||
if isinstance(obj.rid, list):
|
||||
# If it's an array, add worker_id prefix to each element
|
||||
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
|
||||
else:
|
||||
# If it's a single value, add worker_id prefix
|
||||
obj.rid = f"{self.worker_id}_{obj.rid}"
|
||||
|
||||
if self.log_requests:
|
||||
max_length, skip_names, _ = self.log_request_metadata
|
||||
logger.info(
|
||||
@@ -1096,6 +1116,8 @@ class TokenizerManager:
|
||||
async def _wait_for_model_update_from_disk(
|
||||
self, obj: UpdateWeightFromDiskReqInput
|
||||
) -> Tuple[bool, str]:
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
obj = MultiTokenizerWarpper(self.worker_id, obj)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
if self.server_args.dp_size == 1:
|
||||
@@ -1315,6 +1337,8 @@ class TokenizerManager:
|
||||
elif obj.session_id in self.session_futures:
|
||||
return None
|
||||
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
obj = MultiTokenizerWarpper(self.worker_id, obj)
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
self.session_futures[obj.session_id] = asyncio.Future()
|
||||
@@ -1590,7 +1614,6 @@ class TokenizerManager:
|
||||
|
||||
async def handle_loop(self):
|
||||
"""The event loop that handles requests"""
|
||||
|
||||
while True:
|
||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||
self._result_dispatcher(recv_obj)
|
||||
@@ -1610,9 +1633,12 @@ class TokenizerManager:
|
||||
)
|
||||
continue
|
||||
|
||||
origin_rid = rid
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
origin_rid = get_origin_rid(rid)
|
||||
# Build meta_info and return value
|
||||
meta_info = {
|
||||
"id": rid,
|
||||
"id": origin_rid,
|
||||
"finish_reason": recv_obj.finished_reasons[i],
|
||||
"prompt_tokens": recv_obj.prompt_tokens[i],
|
||||
"weight_version": self.server_args.weight_version,
|
||||
@@ -1918,6 +1944,9 @@ class TokenizerManager:
|
||||
if is_health_check_generate_req(recv_obj):
|
||||
return
|
||||
state = self.rid_to_state[recv_obj.rid]
|
||||
origin_rid = recv_obj.rid
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
origin_rid = get_origin_rid(origin_rid)
|
||||
state.finished = True
|
||||
if recv_obj.finished_reason:
|
||||
out = {
|
||||
@@ -1930,7 +1959,7 @@ class TokenizerManager:
|
||||
out = {
|
||||
"text": "",
|
||||
"meta_info": {
|
||||
"id": recv_obj.rid,
|
||||
"id": origin_rid,
|
||||
"finish_reason": {
|
||||
"type": "abort",
|
||||
"message": "Abort before prefill",
|
||||
@@ -2116,6 +2145,8 @@ T = TypeVar("T")
|
||||
class _Communicator(Generic[T]):
|
||||
"""Note: The communicator now only run up to 1 in-flight request at any time."""
|
||||
|
||||
enable_multi_tokenizer = False
|
||||
|
||||
def __init__(self, sender, fan_out: int):
|
||||
self._sender = sender
|
||||
self._fan_out = fan_out
|
||||
@@ -2132,6 +2163,8 @@ class _Communicator(Generic[T]):
|
||||
assert self._result_values is None
|
||||
|
||||
if obj:
|
||||
if _Communicator.enable_multi_tokenizer:
|
||||
obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
|
||||
self._sender.send_pyobj(obj)
|
||||
|
||||
self._result_event = asyncio.Event()
|
||||
|
||||
@@ -128,6 +128,7 @@ class ServerArgs:
|
||||
model_path: str
|
||||
tokenizer_path: Optional[str] = None
|
||||
tokenizer_mode: str = "auto"
|
||||
tokenizer_worker_num: int = 1
|
||||
skip_tokenizer_init: bool = False
|
||||
load_format: str = "auto"
|
||||
model_loader_extra_config: str = "{}"
|
||||
@@ -827,6 +828,12 @@ class ServerArgs:
|
||||
default=ServerArgs.tokenizer_path,
|
||||
help="The path of the tokenizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-worker-num",
|
||||
type=int,
|
||||
default=ServerArgs.tokenizer_worker_num,
|
||||
help="The worker num of the tokenizer manager.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-mode",
|
||||
type=str,
|
||||
@@ -2176,6 +2183,9 @@ class ServerArgs:
|
||||
self.chunked_prefill_size % self.page_size == 0
|
||||
), "chunked_prefill_size must be divisible by page_size"
|
||||
|
||||
# Check multi tokenizer
|
||||
assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
|
||||
|
||||
def check_lora_server_args(self):
|
||||
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
|
||||
|
||||
@@ -2419,6 +2429,9 @@ class PortArgs:
|
||||
# The ipc filename for Scheduler to send metrics
|
||||
metrics_ipc_name: str
|
||||
|
||||
# The ipc filename for Tokenizer and worker tokenizer
|
||||
tokenizer_worker_ipc_name: Optional[str]
|
||||
|
||||
@staticmethod
|
||||
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
|
||||
if server_args.nccl_port is None:
|
||||
@@ -2442,6 +2455,7 @@ class PortArgs:
|
||||
nccl_port=nccl_port,
|
||||
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
||||
metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
||||
tokenizer_worker_ipc_name=None,
|
||||
)
|
||||
else:
|
||||
# DP attention. Use TCP + port to handle both single-node and multi-node.
|
||||
@@ -2475,6 +2489,7 @@ class PortArgs:
|
||||
nccl_port=nccl_port,
|
||||
rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
|
||||
metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
|
||||
tokenizer_worker_ipc_name=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2787,6 +2787,20 @@ def lru_cache_frozenset(maxsize=128):
|
||||
return decorator
|
||||
|
||||
|
||||
def get_worker_ids_from_req_rids(rids):
    """Extract tokenizer-worker ids from rids prefixed ``<worker_id>_``.

    Accepts a single rid string or a list of them; anything else yields [].
    """
    if isinstance(rids, str):
        return [int(rids.split("_")[0])]
    if isinstance(rids, list):
        return [int(rid.split("_")[0]) for rid in rids]
    return []
|
||||
|
||||
|
||||
def get_origin_rid(rid):
    """Strip the leading ``<worker_id>_`` prefix from a rid, if present."""
    _, sep, origin = rid.partition("_")
    return origin if sep else rid
|
||||
|
||||
|
||||
def apply_module_patch(target_module, target_function, wrappers):
|
||||
original_module, original_function = parse_module_path(
|
||||
target_module, target_function, False
|
||||
|
||||
Reference in New Issue
Block a user