Support Multi Process Tokenizer Manager(#6555) (#8964)

Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
ybyang
2025-09-01 16:00:13 +08:00
committed by GitHub
parent 4750cddf68
commit 5f77e1292d
11 changed files with 1030 additions and 80 deletions

View File

@@ -94,6 +94,7 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerWarpper,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -131,6 +132,7 @@ from sglang.srt.utils import (
dataclass_to_string_truncated,
freeze_gc,
get_bool_env_var,
get_origin_rid,
get_zmq_socket,
kill_process_tree,
)
@@ -266,9 +268,15 @@ class TokenizerManager:
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
if self.server_args.tokenizer_worker_num > 1:
# Use tokenizer_worker_ipc_name in multi-tokenizer mode
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
)
else:
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
# Request states
self.no_create_loop = False
@@ -312,35 +320,7 @@ class TokenizerManager:
self.lora_update_lock = asyncio.Lock()
# For PD disaggregtion
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.disaggregation_transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv boostrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
is_create_store = (
self.server_args.node_rank == 0
and self.server_args.disaggregation_transfer_backend == "ascend"
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
self.init_disaggregation()
# For load balancing
self.current_load = 0
@@ -488,6 +468,37 @@ class TokenizerManager:
]
)
def init_disaggregation(self):
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.disaggregation_transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv boostrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
is_create_store = (
self.server_args.node_rank == 0
and self.server_args.disaggregation_transfer_backend == "ascend"
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -497,6 +508,15 @@ class TokenizerManager:
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if self.server_args.tokenizer_worker_num > 1:
# Modify rid, add worker_id
if isinstance(obj.rid, list):
# If it's an array, add worker_id prefix to each element
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
else:
# If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}"
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
@@ -1096,6 +1116,8 @@ class TokenizerManager:
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
) -> Tuple[bool, str]:
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWarpper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
if self.server_args.dp_size == 1:
@@ -1315,6 +1337,8 @@ class TokenizerManager:
elif obj.session_id in self.session_futures:
return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWarpper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future()
@@ -1590,7 +1614,6 @@ class TokenizerManager:
async def handle_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj)
@@ -1610,9 +1633,12 @@ class TokenizerManager:
)
continue
origin_rid = rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(rid)
# Build meta_info and return value
meta_info = {
"id": rid,
"id": origin_rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
"weight_version": self.server_args.weight_version,
@@ -1918,6 +1944,9 @@ class TokenizerManager:
if is_health_check_generate_req(recv_obj):
return
state = self.rid_to_state[recv_obj.rid]
origin_rid = recv_obj.rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(origin_rid)
state.finished = True
if recv_obj.finished_reason:
out = {
@@ -1930,7 +1959,7 @@ class TokenizerManager:
out = {
"text": "",
"meta_info": {
"id": recv_obj.rid,
"id": origin_rid,
"finish_reason": {
"type": "abort",
"message": "Abort before prefill",
@@ -2116,6 +2145,8 @@ T = TypeVar("T")
class _Communicator(Generic[T]):
"""Note: The communicator now only run up to 1 in-flight request at any time."""
enable_multi_tokenizer = False
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
@@ -2132,6 +2163,8 @@ class _Communicator(Generic[T]):
assert self._result_values is None
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
self._result_event = asyncio.Event()