[1/2] Refactor multi-tokenizer manager (#10074)
This commit is contained in:
@@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_tokenizer_manager(
    server_args: ServerArgs, port_args: PortArgs
) -> Tuple[TokenizerManager, TemplateManager]:
    """Create the tokenizer manager together with its template manager.

    Args:
        server_args: Server configuration. Its ``model_path``,
            ``chat_template`` and ``completion_template`` fields are
            forwarded to template initialization.
        port_args: Port/IPC configuration passed to ``TokenizerManager``.

    Returns:
        A ``(tokenizer_manager, template_manager)`` pair. The previous
        annotation advertised a bare ``TokenizerManager`` although a
        2-tuple has always been returned and is unpacked by the caller.
    """
    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)

    # Initialize templates
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )

    return tokenizer_manager, template_manager
|
||||||
|
|
||||||
|
|
||||||
def _launch_subprocesses(
|
def _launch_subprocesses(
|
||||||
server_args: ServerArgs, port_args: Optional[PortArgs] = None
|
server_args: ServerArgs, port_args: Optional[PortArgs] = None
|
||||||
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
|
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
|
||||||
@@ -816,23 +834,15 @@ def _launch_subprocesses(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
detoken_proc.start()
|
detoken_proc.start()
|
||||||
|
|
||||||
|
# Init tokenizer manager first, as the bootstrap server is initialized here
|
||||||
if server_args.tokenizer_worker_num > 1:
|
if server_args.tokenizer_worker_num > 1:
|
||||||
# Launch multi-tokenizer router
|
# Launch multi-tokenizer router
|
||||||
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
|
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
|
||||||
|
|
||||||
# Initialize templates
|
|
||||||
template_manager = None
|
template_manager = None
|
||||||
else:
|
else:
|
||||||
# Launch tokenizer process
|
tokenizer_manager, template_manager = _init_tokenizer_manager(
|
||||||
tokenizer_manager = TokenizerManager(server_args, port_args)
|
server_args, port_args
|
||||||
|
|
||||||
# Initialize templates
|
|
||||||
template_manager = TemplateManager()
|
|
||||||
template_manager.initialize_templates(
|
|
||||||
tokenizer_manager=tokenizer_manager,
|
|
||||||
model_path=server_args.model_path,
|
|
||||||
chat_template=server_args.chat_template,
|
|
||||||
completion_template=server_args.completion_template,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for the model to finish loading
|
# Wait for the model to finish loading
|
||||||
@@ -856,5 +866,7 @@ def _launch_subprocesses(
|
|||||||
|
|
||||||
# Assume all schedulers have the same scheduler_info
|
# Assume all schedulers have the same scheduler_info
|
||||||
scheduler_info = scheduler_infos[0]
|
scheduler_info = scheduler_infos[0]
|
||||||
|
|
||||||
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
||||||
|
|
||||||
return tokenizer_manager, template_manager, scheduler_info
|
return tokenizer_manager, template_manager, scheduler_info
|
||||||
|
|||||||
@@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.multi_tokenizer_mixin import (
|
from sglang.srt.managers.multi_tokenizer_mixin import (
|
||||||
MultiTokenizerManager,
|
MultiTokenizerManager,
|
||||||
deserialize_data,
|
|
||||||
get_main_process_id,
|
get_main_process_id,
|
||||||
read_from_shared_memory,
|
read_from_shared_memory,
|
||||||
write_data_for_multi_tokenizer,
|
write_data_for_multi_tokenizer,
|
||||||
@@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState):
|
|||||||
_global_state = global_state
|
_global_state = global_state
|
||||||
|
|
||||||
|
|
||||||
# Function to set up all middlewares for multi-tokenizer compatibility
|
|
||||||
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
|
|
||||||
"""Setup all middlewares for both single and multi-process modes"""
|
|
||||||
worker_pid = os.getpid()
|
|
||||||
|
|
||||||
if api_key:
|
|
||||||
add_api_key_middleware(app, api_key)
|
|
||||||
logger.info(f"Worker {worker_pid} added API key middleware")
|
|
||||||
|
|
||||||
if enable_metrics:
|
|
||||||
add_prometheus_middleware(app)
|
|
||||||
enable_func_timer()
|
|
||||||
logger.info(f"Worker {worker_pid} added prometheus middleware")
|
|
||||||
|
|
||||||
|
|
||||||
async def init_multi_tokenizer() -> ServerArgs:
|
async def init_multi_tokenizer() -> ServerArgs:
|
||||||
"""Read args information from shm and init tokenizer manager for current process"""
|
"""Read args information from shm and init tokenizer manager for current process"""
|
||||||
pid = os.getpid()
|
pid = os.getpid()
|
||||||
@@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs:
|
|||||||
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
|
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
|
||||||
|
|
||||||
# Read configuration from shared memory
|
# Read configuration from shared memory
|
||||||
port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
|
port_args, server_args, scheduler_info = read_from_shared_memory(
|
||||||
server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
|
f"multi_tokenizer_args_{main_pid}"
|
||||||
scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
|
)
|
||||||
port_args, server_args = deserialize_data(port_args_data, server_args_data)
|
server_args: ServerArgs
|
||||||
scheduler_info = scheduler_info_data
|
|
||||||
|
# API key authentication is not supported in multi-tokenizer mode
|
||||||
|
assert (
|
||||||
|
server_args.api_key is None
|
||||||
|
), "API key is not supported in multi-tokenizer mode"
|
||||||
|
|
||||||
port_args.tokenizer_ipc_name = (
|
port_args.tokenizer_ipc_name = (
|
||||||
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
|
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
|
||||||
@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(fast_api_app: FastAPI):
|
async def lifespan(fast_api_app: FastAPI):
|
||||||
server_args = getattr(fast_api_app, "server_args", None)
|
if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
|
||||||
if server_args is None:
|
|
||||||
# Initialize multi-tokenizer support for worker processes
|
# Initialize multi-tokenizer support for worker processes
|
||||||
fast_api_app.server_args = await init_multi_tokenizer()
|
fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
|
||||||
setup_middlewares(
|
|
||||||
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
|
# only metrics middleware is supported in multi-tokenizer mode
|
||||||
)
|
worker_pid = os.getpid()
|
||||||
|
if fast_api_app.server_args.enable_metrics:
|
||||||
|
add_prometheus_middleware(app)
|
||||||
|
enable_func_timer()
|
||||||
|
|
||||||
|
logger.info(f"Worker {worker_pid} added prometheus middleware")
|
||||||
fast_api_app.warmup_thread = threading.Thread(
|
fast_api_app.warmup_thread = threading.Thread(
|
||||||
target=_wait_and_warmup,
|
target=_wait_and_warmup,
|
||||||
args=(
|
args=(
|
||||||
@@ -1187,12 +1179,10 @@ def launch_server(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if server_args.tokenizer_worker_num > 1:
|
if server_args.tokenizer_worker_num > 1:
|
||||||
port_args_shm, server_args_shm, scheduler_info_shm = (
|
multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
|
||||||
write_data_for_multi_tokenizer(
|
port_args,
|
||||||
port_args,
|
server_args,
|
||||||
server_args,
|
scheduler_info,
|
||||||
scheduler_info,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Add api key authorization
|
# Add api key authorization
|
||||||
@@ -1239,6 +1229,7 @@ def launch_server(
|
|||||||
workers=server_args.tokenizer_worker_num,
|
workers=server_args.tokenizer_worker_num,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
app.is_single_tokenizer_mode = True
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
host=server_args.host,
|
host=server_args.host,
|
||||||
@@ -1249,10 +1240,8 @@ def launch_server(
|
|||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if server_args.tokenizer_worker_num > 1:
|
if server_args.tokenizer_worker_num > 1:
|
||||||
port_args_shm.unlink()
|
multi_tokenizer_args_shm.unlink()
|
||||||
server_args_shm.unlink()
|
_global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
|
||||||
scheduler_info_shm.unlink()
|
|
||||||
_global_state.tokenizer_manager.clear_tokenizer_mapping()
|
|
||||||
else:
|
else:
|
||||||
warmup_thread.join()
|
warmup_thread.join()
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
FreezeGCReq,
|
FreezeGCReq,
|
||||||
MultiTokenizerRegisterReq,
|
MultiTokenizerRegisterReq,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
|
from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
configure_logger,
|
configure_logger,
|
||||||
@@ -69,7 +69,7 @@ class DecodeStatus:
|
|||||||
sent_offset: int = 0
|
sent_offset: int = 0
|
||||||
|
|
||||||
|
|
||||||
class DetokenizerManager(MultiTokenizerMixin):
|
class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -289,11 +289,11 @@ def run_detokenizer_process(
|
|||||||
try:
|
try:
|
||||||
manager = DetokenizerManager(server_args, port_args)
|
manager = DetokenizerManager(server_args, port_args)
|
||||||
if server_args.tokenizer_worker_num > 1:
|
if server_args.tokenizer_worker_num > 1:
|
||||||
manager.multi_tokenizer_manager_event_loop()
|
manager.multi_http_worker_event_loop()
|
||||||
else:
|
else:
|
||||||
manager.event_loop()
|
manager.event_loop()
|
||||||
except Exception:
|
except Exception:
|
||||||
manager.clear_tokenizer_mapping()
|
manager.socket_mapping.clear_all_sockets()
|
||||||
traceback = get_exception_traceback()
|
traceback = get_exception_traceback()
|
||||||
logger.error(f"DetokenizerManager hit an exception: {traceback}")
|
logger.error(f"DetokenizerManager hit an exception: {traceback}")
|
||||||
parent_process.send_signal(signal.SIGQUIT)
|
parent_process.send_signal(signal.SIGQUIT)
|
||||||
|
|||||||
46
python/sglang/srt/managers/disagg_service.py
Normal file
46
python/sglang/srt/managers/disagg_service.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Start bootstrap/kv-store-related server"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
|
||||||
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
DisaggregationMode,
|
||||||
|
KVClassType,
|
||||||
|
TransferBackend,
|
||||||
|
get_kv_class,
|
||||||
|
)
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
|
def start_disagg_service(
    server_args: ServerArgs,
):
    """Start the KV bootstrap server used for prefill/decode disaggregation.

    The bootstrap server is only created when ``server_args`` selects the
    PREFILL disaggregation mode. On node rank 0 with the ASCEND transfer
    backend, the MF config store is additionally created from the
    ``ASCEND_MF_STORE_URL`` environment variable.

    Args:
        server_args: Server configuration; reads ``disaggregation_mode``,
            ``disaggregation_transfer_backend``, ``host``,
            ``disaggregation_bootstrap_port`` and ``node_rank``.

    Returns:
        The bootstrap server instance in PREFILL mode, otherwise ``None``.

    Raises:
        RuntimeError: If creating the Ascend MF config store fails. The
            underlying exception is chained as ``__cause__``.
    """
    disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
    transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)

    # Only the prefill side hosts a KV bootstrap server.
    if disagg_mode != DisaggregationMode.PREFILL:
        return None

    kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
        transfer_backend, KVClassType.BOOTSTRAP_SERVER
    )
    bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
        host=server_args.host,
        port=server_args.disaggregation_bootstrap_port,
    )

    is_create_store = (
        server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
    )
    if is_create_store:
        try:
            # Imported lazily: mf_adapter is only needed for the Ascend backend.
            from mf_adapter import create_config_store

            ascend_url = os.getenv("ASCEND_MF_STORE_URL")
            create_config_store(ascend_url)
        except Exception as e:
            error_message = "Failed create mf store, invalid ascend_url."
            error_message += f" With exception {e}"
            # Raising the bare string would itself fail with
            # "TypeError: exceptions must derive from BaseException" and
            # mask the real cause, so wrap it in a proper exception.
            raise RuntimeError(error_message) from e

    return bootstrap_server
|
||||||
@@ -13,21 +13,21 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
|
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import dataclasses
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as multiprocessing
|
import multiprocessing as multiprocessing
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from multiprocessing import shared_memory
|
from multiprocessing import shared_memory
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import setproctitle
|
import setproctitle
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
||||||
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
BatchMultimodalOut,
|
BatchMultimodalOut,
|
||||||
@@ -44,302 +44,296 @@ from sglang.utils import get_exception_traceback
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MultiTokenizerMixin:
|
class SocketMapping:
|
||||||
"""Mixin class for MultiTokenizerManager and DetokenizerManager"""
|
def __init__(self):
|
||||||
|
self._zmq_context = zmq.Context()
|
||||||
|
self._mapping: Dict[str, zmq.Socket] = {}
|
||||||
|
|
||||||
def create_sockets_mapping(self):
|
def clear_all_sockets(self):
|
||||||
if not hasattr(self, "tokenizer_mapping"):
|
for socket in self._mapping.values():
|
||||||
self.tokenizer_mapping = {}
|
socket.close()
|
||||||
# Create ZMQ context if needed
|
self._mapping.clear()
|
||||||
if not hasattr(self, "_zmq_context"):
|
|
||||||
self._zmq_context = zmq.Context()
|
|
||||||
|
|
||||||
def init_tokenizer_mapping(
|
def register_ipc_mapping(
|
||||||
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
|
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
|
||||||
):
|
):
|
||||||
"""init tokenizer mapping from register request"""
|
type_str = "tokenizer" if is_tokenizer else "detokenizer"
|
||||||
ipc_name = recv_obj.ipc_name
|
if worker_id in self._mapping:
|
||||||
worker_id_int = int(worker_id)
|
logger.warning(
|
||||||
|
f"{type_str} already registered with worker {worker_id}, skipping..."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
f"{type_str} not registered with worker {worker_id}, registering..."
|
||||||
|
)
|
||||||
|
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
|
||||||
|
self._mapping[worker_id] = socket
|
||||||
|
self._mapping[worker_id].send_pyobj(recv_obj)
|
||||||
|
|
||||||
if worker_id_int not in self.tokenizer_mapping:
|
def send_output(self, worker_id: str, output: Any):
|
||||||
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
|
if worker_id not in self._mapping:
|
||||||
self.tokenizer_mapping[worker_id_int] = socket
|
logger.error(
|
||||||
self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
|
f"worker ID {worker_id} not registered. Check if the server Process is alive"
|
||||||
return True
|
)
|
||||||
else:
|
return
|
||||||
return False
|
self._mapping[worker_id].send_pyobj(output)
|
||||||
|
|
||||||
def register_tokenizer_ipc(self, recv_obj, worker_id):
|
|
||||||
if worker_id not in self.tokenizer_mapping:
|
|
||||||
# register the worker if not already done
|
|
||||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
|
||||||
return self.init_tokenizer_mapping(recv_obj, worker_id)
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"Worker {worker_id} not registered and not found in tokenizer mapping . "
|
|
||||||
"Please ensure the worker is registered correctly."
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _handle_output_by_index(self, output, i):
|
def _handle_output_by_index(output, i):
|
||||||
"""NOTE: A maintainable method is better here."""
|
"""NOTE: A maintainable method is better here."""
|
||||||
if isinstance(output, BatchTokenIDOut):
|
if isinstance(output, BatchTokenIDOut):
|
||||||
new_output = BatchTokenIDOut(
|
new_output = BatchTokenIDOut(
|
||||||
rids=[output.rids[i]],
|
rids=[output.rids[i]],
|
||||||
finished_reasons=(
|
finished_reasons=(
|
||||||
[output.finished_reasons[i]]
|
[output.finished_reasons[i]]
|
||||||
if len(output.finished_reasons) > i
|
if len(output.finished_reasons) > i
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
decoded_texts=(
|
decoded_texts=(
|
||||||
[output.decoded_texts[i]] if len(output.decoded_texts) > i else None
|
[output.decoded_texts[i]] if len(output.decoded_texts) > i else None
|
||||||
),
|
),
|
||||||
decode_ids=(
|
decode_ids=([output.decode_ids[i]] if len(output.decode_ids) > i else None),
|
||||||
[output.decode_ids[i]] if len(output.decode_ids) > i else None
|
read_offsets=(
|
||||||
),
|
[output.read_offsets[i]] if len(output.read_offsets) > i else None
|
||||||
read_offsets=(
|
),
|
||||||
[output.read_offsets[i]] if len(output.read_offsets) > i else None
|
output_ids=(
|
||||||
),
|
[output.output_ids[i]]
|
||||||
output_ids=(
|
if output.output_ids and len(output.output_ids) > i
|
||||||
[output.output_ids[i]]
|
else None
|
||||||
if output.output_ids and len(output.output_ids) > i
|
),
|
||||||
else None
|
skip_special_tokens=(
|
||||||
),
|
[output.skip_special_tokens[i]]
|
||||||
skip_special_tokens=(
|
if len(output.skip_special_tokens) > i
|
||||||
[output.skip_special_tokens[i]]
|
else None
|
||||||
if len(output.skip_special_tokens) > i
|
),
|
||||||
else None
|
spaces_between_special_tokens=(
|
||||||
),
|
[output.spaces_between_special_tokens[i]]
|
||||||
spaces_between_special_tokens=(
|
if len(output.spaces_between_special_tokens) > i
|
||||||
[output.spaces_between_special_tokens[i]]
|
else None
|
||||||
if len(output.spaces_between_special_tokens) > i
|
),
|
||||||
else None
|
no_stop_trim=(
|
||||||
),
|
[output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
|
||||||
no_stop_trim=(
|
),
|
||||||
[output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
|
prompt_tokens=(
|
||||||
),
|
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
||||||
prompt_tokens=(
|
),
|
||||||
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
completion_tokens=(
|
||||||
),
|
[output.completion_tokens[i]]
|
||||||
completion_tokens=(
|
if len(output.completion_tokens) > i
|
||||||
[output.completion_tokens[i]]
|
else None
|
||||||
if len(output.completion_tokens) > i
|
),
|
||||||
else None
|
cached_tokens=(
|
||||||
),
|
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
||||||
cached_tokens=(
|
),
|
||||||
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
spec_verify_ct=(
|
||||||
),
|
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
|
||||||
spec_verify_ct=(
|
),
|
||||||
[output.spec_verify_ct[i]]
|
input_token_logprobs_val=(
|
||||||
if len(output.spec_verify_ct) > i
|
[output.input_token_logprobs_val[i]]
|
||||||
else None
|
if output.input_token_logprobs_val
|
||||||
),
|
else None
|
||||||
input_token_logprobs_val=(
|
),
|
||||||
[output.input_token_logprobs_val[i]]
|
input_token_logprobs_idx=(
|
||||||
if output.input_token_logprobs_val
|
[output.input_token_logprobs_idx[i]]
|
||||||
else None
|
if output.input_token_logprobs_idx
|
||||||
),
|
else None
|
||||||
input_token_logprobs_idx=(
|
),
|
||||||
[output.input_token_logprobs_idx[i]]
|
output_token_logprobs_val=(
|
||||||
if output.input_token_logprobs_idx
|
[output.output_token_logprobs_val[i]]
|
||||||
else None
|
if output.output_token_logprobs_val
|
||||||
),
|
else None
|
||||||
output_token_logprobs_val=(
|
),
|
||||||
[output.output_token_logprobs_val[i]]
|
output_token_logprobs_idx=(
|
||||||
if output.output_token_logprobs_val
|
[output.output_token_logprobs_idx[i]]
|
||||||
else None
|
if output.output_token_logprobs_idx
|
||||||
),
|
else None
|
||||||
output_token_logprobs_idx=(
|
),
|
||||||
[output.output_token_logprobs_idx[i]]
|
input_top_logprobs_val=(
|
||||||
if output.output_token_logprobs_idx
|
[output.input_top_logprobs_val[i]]
|
||||||
else None
|
if output.input_top_logprobs_val
|
||||||
),
|
else None
|
||||||
input_top_logprobs_val=(
|
),
|
||||||
[output.input_top_logprobs_val[i]]
|
input_top_logprobs_idx=(
|
||||||
if output.input_top_logprobs_val
|
[output.input_top_logprobs_idx[i]]
|
||||||
else None
|
if output.input_top_logprobs_idx
|
||||||
),
|
else None
|
||||||
input_top_logprobs_idx=(
|
),
|
||||||
[output.input_top_logprobs_idx[i]]
|
output_top_logprobs_val=(
|
||||||
if output.input_top_logprobs_idx
|
[output.output_top_logprobs_val[i]]
|
||||||
else None
|
if output.output_top_logprobs_val
|
||||||
),
|
else None
|
||||||
output_top_logprobs_val=(
|
),
|
||||||
[output.output_top_logprobs_val[i]]
|
output_top_logprobs_idx=(
|
||||||
if output.output_top_logprobs_val
|
[output.output_top_logprobs_idx[i]]
|
||||||
else None
|
if output.output_top_logprobs_idx
|
||||||
),
|
else None
|
||||||
output_top_logprobs_idx=(
|
),
|
||||||
[output.output_top_logprobs_idx[i]]
|
input_token_ids_logprobs_val=(
|
||||||
if output.output_top_logprobs_idx
|
[output.input_token_ids_logprobs_val[i]]
|
||||||
else None
|
if output.input_token_ids_logprobs_val
|
||||||
),
|
else None
|
||||||
input_token_ids_logprobs_val=(
|
),
|
||||||
[output.input_token_ids_logprobs_val[i]]
|
input_token_ids_logprobs_idx=(
|
||||||
if output.input_token_ids_logprobs_val
|
[output.input_token_ids_logprobs_idx[i]]
|
||||||
else None
|
if output.input_token_ids_logprobs_idx
|
||||||
),
|
else None
|
||||||
input_token_ids_logprobs_idx=(
|
),
|
||||||
[output.input_token_ids_logprobs_idx[i]]
|
output_token_ids_logprobs_val=(
|
||||||
if output.input_token_ids_logprobs_idx
|
[output.output_token_ids_logprobs_val[i]]
|
||||||
else None
|
if output.output_token_ids_logprobs_val
|
||||||
),
|
else None
|
||||||
output_token_ids_logprobs_val=(
|
),
|
||||||
[output.output_token_ids_logprobs_val[i]]
|
output_token_ids_logprobs_idx=(
|
||||||
if output.output_token_ids_logprobs_val
|
[output.output_token_ids_logprobs_idx[i]]
|
||||||
else None
|
if output.output_token_ids_logprobs_idx
|
||||||
),
|
else None
|
||||||
output_token_ids_logprobs_idx=(
|
),
|
||||||
[output.output_token_ids_logprobs_idx[i]]
|
output_hidden_states=(
|
||||||
if output.output_token_ids_logprobs_idx
|
[output.output_hidden_states[i]]
|
||||||
else None
|
if output.output_hidden_states
|
||||||
),
|
else None
|
||||||
output_hidden_states=(
|
),
|
||||||
[output.output_hidden_states[i]]
|
)
|
||||||
if output.output_hidden_states
|
elif isinstance(output, BatchEmbeddingOut):
|
||||||
else None
|
new_output = BatchEmbeddingOut(
|
||||||
),
|
rids=[output.rids[i]],
|
||||||
)
|
finished_reasons=(
|
||||||
elif isinstance(output, BatchEmbeddingOut):
|
[output.finished_reasons[i]]
|
||||||
new_output = BatchEmbeddingOut(
|
if len(output.finished_reasons) > i
|
||||||
rids=[output.rids[i]],
|
else None
|
||||||
finished_reasons=(
|
),
|
||||||
[output.finished_reasons[i]]
|
embeddings=([output.embeddings[i]] if len(output.embeddings) > i else None),
|
||||||
if len(output.finished_reasons) > i
|
prompt_tokens=(
|
||||||
else None
|
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
||||||
),
|
),
|
||||||
embeddings=(
|
cached_tokens=(
|
||||||
[output.embeddings[i]] if len(output.embeddings) > i else None
|
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
||||||
),
|
),
|
||||||
prompt_tokens=(
|
)
|
||||||
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
elif isinstance(output, BatchStrOut):
|
||||||
),
|
new_output = BatchStrOut(
|
||||||
cached_tokens=(
|
rids=[output.rids[i]],
|
||||||
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
finished_reasons=(
|
||||||
),
|
[output.finished_reasons[i]]
|
||||||
)
|
if len(output.finished_reasons) > i
|
||||||
elif isinstance(output, BatchStrOut):
|
else None
|
||||||
new_output = BatchStrOut(
|
),
|
||||||
rids=[output.rids[i]],
|
output_strs=(
|
||||||
finished_reasons=(
|
[output.output_strs[i]] if len(output.output_strs) > i else None
|
||||||
[output.finished_reasons[i]]
|
),
|
||||||
if len(output.finished_reasons) > i
|
output_ids=(
|
||||||
else None
|
[output.output_ids[i]]
|
||||||
),
|
if output.output_ids and len(output.output_ids) > i
|
||||||
output_strs=(
|
else None
|
||||||
[output.output_strs[i]] if len(output.output_strs) > i else None
|
),
|
||||||
),
|
prompt_tokens=(
|
||||||
output_ids=(
|
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
||||||
[output.output_ids[i]]
|
),
|
||||||
if output.output_ids and len(output.output_ids) > i
|
completion_tokens=(
|
||||||
else None
|
[output.completion_tokens[i]]
|
||||||
),
|
if len(output.completion_tokens) > i
|
||||||
prompt_tokens=(
|
else None
|
||||||
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
),
|
||||||
),
|
cached_tokens=(
|
||||||
completion_tokens=(
|
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
||||||
[output.completion_tokens[i]]
|
),
|
||||||
if len(output.completion_tokens) > i
|
spec_verify_ct=(
|
||||||
else None
|
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
|
||||||
),
|
),
|
||||||
cached_tokens=(
|
input_token_logprobs_val=(
|
||||||
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
[output.input_token_logprobs_val[i]]
|
||||||
),
|
if output.input_token_logprobs_val
|
||||||
spec_verify_ct=(
|
else None
|
||||||
[output.spec_verify_ct[i]]
|
),
|
||||||
if len(output.spec_verify_ct) > i
|
input_token_logprobs_idx=(
|
||||||
else None
|
[output.input_token_logprobs_idx[i]]
|
||||||
),
|
if output.input_token_logprobs_idx
|
||||||
input_token_logprobs_val=(
|
else None
|
||||||
[output.input_token_logprobs_val[i]]
|
),
|
||||||
if output.input_token_logprobs_val
|
output_token_logprobs_val=(
|
||||||
else None
|
[output.output_token_logprobs_val[i]]
|
||||||
),
|
if output.output_token_logprobs_val
|
||||||
input_token_logprobs_idx=(
|
else None
|
||||||
[output.input_token_logprobs_idx[i]]
|
),
|
||||||
if output.input_token_logprobs_idx
|
output_token_logprobs_idx=(
|
||||||
else None
|
[output.output_token_logprobs_idx[i]]
|
||||||
),
|
if output.output_token_logprobs_idx
|
||||||
output_token_logprobs_val=(
|
else None
|
||||||
[output.output_token_logprobs_val[i]]
|
),
|
||||||
if output.output_token_logprobs_val
|
input_top_logprobs_val=(
|
||||||
else None
|
[output.input_top_logprobs_val[i]]
|
||||||
),
|
if output.input_top_logprobs_val
|
||||||
output_token_logprobs_idx=(
|
else None
|
||||||
[output.output_token_logprobs_idx[i]]
|
),
|
||||||
if output.output_token_logprobs_idx
|
input_top_logprobs_idx=(
|
||||||
else None
|
[output.input_top_logprobs_idx[i]]
|
||||||
),
|
if output.input_top_logprobs_idx
|
||||||
input_top_logprobs_val=(
|
else None
|
||||||
[output.input_top_logprobs_val[i]]
|
),
|
||||||
if output.input_top_logprobs_val
|
output_top_logprobs_val=(
|
||||||
else None
|
[output.output_top_logprobs_val[i]]
|
||||||
),
|
if output.output_top_logprobs_val
|
||||||
input_top_logprobs_idx=(
|
else None
|
||||||
[output.input_top_logprobs_idx[i]]
|
),
|
||||||
if output.input_top_logprobs_idx
|
output_top_logprobs_idx=(
|
||||||
else None
|
[output.output_top_logprobs_idx[i]]
|
||||||
),
|
if output.output_top_logprobs_idx
|
||||||
output_top_logprobs_val=(
|
else None
|
||||||
[output.output_top_logprobs_val[i]]
|
),
|
||||||
if output.output_top_logprobs_val
|
input_token_ids_logprobs_val=(
|
||||||
else None
|
[output.input_token_ids_logprobs_val[i]]
|
||||||
),
|
if output.input_token_ids_logprobs_val
|
||||||
output_top_logprobs_idx=(
|
else None
|
||||||
[output.output_top_logprobs_idx[i]]
|
),
|
||||||
if output.output_top_logprobs_idx
|
input_token_ids_logprobs_idx=(
|
||||||
else None
|
[output.input_token_ids_logprobs_idx[i]]
|
||||||
),
|
if output.input_token_ids_logprobs_idx
|
||||||
input_token_ids_logprobs_val=(
|
else None
|
||||||
[output.input_token_ids_logprobs_val[i]]
|
),
|
||||||
if output.input_token_ids_logprobs_val
|
output_token_ids_logprobs_val=(
|
||||||
else None
|
[output.output_token_ids_logprobs_val[i]]
|
||||||
),
|
if output.output_token_ids_logprobs_val
|
||||||
input_token_ids_logprobs_idx=(
|
else None
|
||||||
[output.input_token_ids_logprobs_idx[i]]
|
),
|
||||||
if output.input_token_ids_logprobs_idx
|
output_token_ids_logprobs_idx=(
|
||||||
else None
|
[output.output_token_ids_logprobs_idx[i]]
|
||||||
),
|
if output.output_token_ids_logprobs_idx
|
||||||
output_token_ids_logprobs_val=(
|
else None
|
||||||
[output.output_token_ids_logprobs_val[i]]
|
),
|
||||||
if output.output_token_ids_logprobs_val
|
output_hidden_states=(
|
||||||
else None
|
[output.output_hidden_states[i]]
|
||||||
),
|
if output.output_hidden_states
|
||||||
output_token_ids_logprobs_idx=(
|
else None
|
||||||
[output.output_token_ids_logprobs_idx[i]]
|
),
|
||||||
if output.output_token_ids_logprobs_idx
|
)
|
||||||
else None
|
elif isinstance(output, BatchMultimodalOut):
|
||||||
),
|
new_output = BatchMultimodalOut(
|
||||||
output_hidden_states=(
|
rids=[output.rids[i]],
|
||||||
[output.output_hidden_states[i]]
|
finished_reasons=(
|
||||||
if output.output_hidden_states
|
[output.finished_reasons[i]]
|
||||||
else None
|
if len(output.finished_reasons) > i
|
||||||
),
|
else None
|
||||||
)
|
),
|
||||||
elif isinstance(output, BatchMultimodalOut):
|
outputs=([output.outputs[i]] if len(output.outputs) > i else None),
|
||||||
new_output = BatchMultimodalOut(
|
prompt_tokens=(
|
||||||
rids=[output.rids[i]],
|
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
||||||
finished_reasons=(
|
),
|
||||||
[output.finished_reasons[i]]
|
completion_tokens=(
|
||||||
if len(output.finished_reasons) > i
|
[output.completion_tokens[i]]
|
||||||
else None
|
if len(output.completion_tokens) > i
|
||||||
),
|
else None
|
||||||
outputs=([output.outputs[i]] if len(output.outputs) > i else None),
|
),
|
||||||
prompt_tokens=(
|
cached_tokens=(
|
||||||
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
|
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
||||||
),
|
),
|
||||||
completion_tokens=(
|
)
|
||||||
[output.completion_tokens[i]]
|
else:
|
||||||
if len(output.completion_tokens) > i
|
new_output = output
|
||||||
else None
|
return new_output
|
||||||
),
|
|
||||||
cached_tokens=(
|
|
||||||
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
|
class MultiHttpWorkerDetokenizerMixin:
|
||||||
),
|
"""Mixin class for MultiTokenizerManager and DetokenizerManager"""
|
||||||
)
|
|
||||||
else:
|
|
||||||
new_output = output
|
|
||||||
return new_output
|
|
||||||
|
|
||||||
def get_worker_ids_from_req_rids(self, rids):
|
def get_worker_ids_from_req_rids(self, rids):
|
||||||
if isinstance(rids, list):
|
if isinstance(rids, list):
|
||||||
@@ -350,9 +344,9 @@ class MultiTokenizerMixin:
|
|||||||
worker_ids = []
|
worker_ids = []
|
||||||
return worker_ids
|
return worker_ids
|
||||||
|
|
||||||
def multi_tokenizer_manager_event_loop(self):
|
def multi_http_worker_event_loop(self):
|
||||||
"""The event loop that handles requests, for multi tokenizer manager mode only"""
|
"""The event loop that handles requests, for multi multi-http-worker mode"""
|
||||||
self.create_sockets_mapping()
|
self.socket_mapping = SocketMapping()
|
||||||
while True:
|
while True:
|
||||||
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
||||||
output = self._request_dispatcher(recv_obj)
|
output = self._request_dispatcher(recv_obj)
|
||||||
@@ -369,31 +363,15 @@ class MultiTokenizerMixin:
|
|||||||
# Send data using the corresponding socket
|
# Send data using the corresponding socket
|
||||||
for i, worker_id in enumerate(worker_ids):
|
for i, worker_id in enumerate(worker_ids):
|
||||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||||
if self.register_tokenizer_ipc(recv_obj, worker_id):
|
self.socket_mapping.register_ipc_mapping(
|
||||||
logger.info(
|
recv_obj, worker_id, is_tokenizer=False
|
||||||
f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
|
)
|
||||||
)
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
if worker_id not in self.tokenizer_mapping:
|
new_output = _handle_output_by_index(output, i)
|
||||||
logger.error(
|
self.socket_mapping.send_output(worker_id, new_output)
|
||||||
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
new_output = self._handle_output_by_index(output, i)
|
|
||||||
self.tokenizer_mapping[worker_id].send_pyobj(new_output)
|
|
||||||
|
|
||||||
def clear_tokenizer_mapping(self):
|
|
||||||
if hasattr(self, "tokenizer_mapping"):
|
|
||||||
for socket in self.tokenizer_mapping.values():
|
|
||||||
try:
|
|
||||||
socket.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to close socket: {e}")
|
|
||||||
self.tokenizer_mapping.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
|
class MultiTokenizerRouter:
|
||||||
"""A router to receive requests from MultiTokenizerManager"""
|
"""A router to receive requests from MultiTokenizerManager"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -422,7 +400,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
|
|||||||
self._handle_task = asyncio.run_coroutine_threadsafe(
|
self._handle_task = asyncio.run_coroutine_threadsafe(
|
||||||
print_exception_wrapper(self.handle_loop), self._loop
|
print_exception_wrapper(self.handle_loop), self._loop
|
||||||
)
|
)
|
||||||
self.init_disaggregation()
|
self.disaggregation_bootstrap_server = start_disagg_service(self.server_args)
|
||||||
|
|
||||||
def _run_loop(self):
|
def _run_loop(self):
|
||||||
self._loop.run_forever()
|
self._loop.run_forever()
|
||||||
@@ -434,7 +412,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
|
|||||||
|
|
||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
# special reqs will recv from scheduler, need to route to right worker
|
# special reqs will recv from scheduler, need to route to right worker
|
||||||
self.create_sockets_mapping()
|
self.socket_mapping = SocketMapping()
|
||||||
while True:
|
while True:
|
||||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
await self._distribute_result_to_workers(recv_obj)
|
await self._distribute_result_to_workers(recv_obj)
|
||||||
@@ -454,22 +432,15 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
|
|||||||
# Distribute result to each worker
|
# Distribute result to each worker
|
||||||
for i, worker_id in enumerate(worker_ids):
|
for i, worker_id in enumerate(worker_ids):
|
||||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||||
if self.register_tokenizer_ipc(recv_obj, worker_id):
|
self.socket_mapping.register_ipc_mapping(
|
||||||
logger.info(
|
recv_obj, worker_id, is_tokenizer=True
|
||||||
f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
|
)
|
||||||
)
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
if worker_id not in self.tokenizer_mapping:
|
new_recv_obj = _handle_output_by_index(recv_obj, i)
|
||||||
logger.error(
|
self.socket_mapping.send_output(worker_id, new_recv_obj)
|
||||||
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
new_recv_obj = self._handle_output_by_index(recv_obj, i)
|
|
||||||
self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
|
class MultiTokenizerManager(TokenizerManager):
|
||||||
"""Multi Process Tokenizer Manager that tokenizes the text."""
|
"""Multi Process Tokenizer Manager that tokenizes the text."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -535,42 +506,14 @@ async def print_exception_wrapper(func):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def serialize_port_args(port_args: PortArgs) -> dict:
|
def get_main_process_id() -> int:
|
||||||
"""Serialize PortArgs into a shareable dictionary"""
|
"""Get the main process ID"""
|
||||||
return {
|
return multiprocessing.current_process()._parent_pid
|
||||||
"tokenizer_ipc_name": port_args.tokenizer_ipc_name,
|
|
||||||
"scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
|
|
||||||
"detokenizer_ipc_name": port_args.detokenizer_ipc_name,
|
|
||||||
"nccl_port": port_args.nccl_port,
|
|
||||||
"rpc_ipc_name": port_args.rpc_ipc_name,
|
|
||||||
"metrics_ipc_name": port_args.metrics_ipc_name,
|
|
||||||
"tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_data(port_args: dict, server_args: dict):
|
def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory:
|
||||||
"""Deserialize data from shared dictionaries"""
|
|
||||||
return PortArgs(**port_args), ServerArgs(**server_args)
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_server_args(server_args: ServerArgs) -> dict:
|
|
||||||
"""Serialize ServerArgs into a shareable dictionary"""
|
|
||||||
return dataclasses.asdict(server_args)
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
|
|
||||||
"""Serialize scheduler_info into a shareable dictionary"""
|
|
||||||
return scheduler_info
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_scheduler_info(data: dict) -> Dict:
|
|
||||||
"""Deserialize scheduler_info from a shared dictionary"""
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
|
|
||||||
"""Write data to shared memory"""
|
"""Write data to shared memory"""
|
||||||
serialized = json.dumps(data).encode("utf-8")
|
serialized = pickle.dumps(obj)
|
||||||
size = len(serialized)
|
size = len(serialized)
|
||||||
try:
|
try:
|
||||||
# Try to open existing shared memory
|
# Try to open existing shared memory
|
||||||
@@ -588,22 +531,17 @@ def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
|
|||||||
return shm
|
return shm
|
||||||
|
|
||||||
|
|
||||||
def read_from_shared_memory(name: str) -> dict:
|
def read_from_shared_memory(name: str) -> Any:
|
||||||
"""Read data from shared memory"""
|
"""Read data from shared memory"""
|
||||||
try:
|
try:
|
||||||
shm = shared_memory.SharedMemory(name=name)
|
shm = shared_memory.SharedMemory(name=name)
|
||||||
data = json.loads(bytes(shm.buf).decode("utf-8"))
|
data = pickle.loads(bytes(shm.buf))
|
||||||
shm.close()
|
shm.close()
|
||||||
return data
|
return data
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise FileNotFoundError(f"Shared memory {name} not found")
|
raise FileNotFoundError(f"Shared memory {name} not found")
|
||||||
|
|
||||||
|
|
||||||
def get_main_process_id() -> int:
|
|
||||||
"""Get the main process ID"""
|
|
||||||
return multiprocessing.current_process()._parent_pid
|
|
||||||
|
|
||||||
|
|
||||||
def write_data_for_multi_tokenizer(
|
def write_data_for_multi_tokenizer(
|
||||||
port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
|
port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
|
||||||
):
|
):
|
||||||
@@ -612,22 +550,8 @@ def write_data_for_multi_tokenizer(
|
|||||||
main_pid = get_main_process_id()
|
main_pid = get_main_process_id()
|
||||||
current_pid = os.getpid()
|
current_pid = os.getpid()
|
||||||
logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
|
logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
|
||||||
|
args = (port_args, server_args, scheduler_info)
|
||||||
|
args_shm = write_to_shared_memory(args, f"multi_tokenizer_args_{current_pid}")
|
||||||
|
args_shm.close()
|
||||||
|
|
||||||
# Write port_args to shared memory
|
return args_shm
|
||||||
port_args_shm = write_to_shared_memory(
|
|
||||||
serialize_port_args(port_args), f"port_args_{current_pid}"
|
|
||||||
)
|
|
||||||
# Write server_args to shared memory
|
|
||||||
server_args_shm = write_to_shared_memory(
|
|
||||||
serialize_server_args(server_args), f"server_args_{current_pid}"
|
|
||||||
)
|
|
||||||
# Write scheduler_info to shared memory
|
|
||||||
scheduler_info_shm = write_to_shared_memory(
|
|
||||||
serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
|
|
||||||
)
|
|
||||||
|
|
||||||
port_args_shm.close()
|
|
||||||
server_args_shm.close()
|
|
||||||
scheduler_info_shm.close()
|
|
||||||
|
|
||||||
return port_args_shm, server_args_shm, scheduler_info_shm
|
|
||||||
|
|||||||
@@ -54,19 +54,14 @@ from fastapi import BackgroundTasks
|
|||||||
|
|
||||||
from sglang.srt.aio_rwlock import RWLock
|
from sglang.srt.aio_rwlock import RWLock
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.disaggregation.utils import (
|
|
||||||
DisaggregationMode,
|
|
||||||
KVClassType,
|
|
||||||
TransferBackend,
|
|
||||||
get_kv_class,
|
|
||||||
)
|
|
||||||
from sglang.srt.hf_transformers_utils import (
|
from sglang.srt.hf_transformers_utils import (
|
||||||
get_processor,
|
get_processor,
|
||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
get_tokenizer_from_processor,
|
get_tokenizer_from_processor,
|
||||||
)
|
)
|
||||||
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
||||||
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
@@ -321,8 +316,10 @@ class TokenizerManager:
|
|||||||
# LoRA updates and inference to overlap.
|
# LoRA updates and inference to overlap.
|
||||||
self.lora_update_lock = asyncio.Lock()
|
self.lora_update_lock = asyncio.Lock()
|
||||||
|
|
||||||
# For PD disaggregtion
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
self.init_disaggregation()
|
self.server_args.disaggregation_mode
|
||||||
|
)
|
||||||
|
self.bootstrap_server = start_disagg_service(self.server_args)
|
||||||
|
|
||||||
# For load balancing
|
# For load balancing
|
||||||
self.current_load = 0
|
self.current_load = 0
|
||||||
@@ -471,38 +468,6 @@ class TokenizerManager:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_disaggregation(self):
|
|
||||||
self.disaggregation_mode = DisaggregationMode(
|
|
||||||
self.server_args.disaggregation_mode
|
|
||||||
)
|
|
||||||
self.disaggregation_transfer_backend = TransferBackend(
|
|
||||||
self.server_args.disaggregation_transfer_backend
|
|
||||||
)
|
|
||||||
# Start kv boostrap server on prefill
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
|
||||||
# only start bootstrap server on prefill tm
|
|
||||||
kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
|
|
||||||
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
|
||||||
)
|
|
||||||
self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
|
|
||||||
host=self.server_args.host,
|
|
||||||
port=self.server_args.disaggregation_bootstrap_port,
|
|
||||||
)
|
|
||||||
is_create_store = (
|
|
||||||
self.server_args.node_rank == 0
|
|
||||||
and self.server_args.disaggregation_transfer_backend == "ascend"
|
|
||||||
)
|
|
||||||
if is_create_store:
|
|
||||||
try:
|
|
||||||
from mf_adapter import create_config_store
|
|
||||||
|
|
||||||
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
|
||||||
create_config_store(ascend_url)
|
|
||||||
except Exception as e:
|
|
||||||
error_message = f"Failed create mf store, invalid ascend_url."
|
|
||||||
error_message += f" With exception {e}"
|
|
||||||
raise error_message
|
|
||||||
|
|
||||||
async def generate_request(
|
async def generate_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
|
|||||||
Reference in New Issue
Block a user