Support Multi Process Tokenizer Manager (#6555)

Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: lw9527 <952799980@qq.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
This commit is contained in:
ybyang
2025-08-08 16:45:50 +08:00
committed by GitHub
parent 6ee6619b7a
commit 7490e3f67d
9 changed files with 1133 additions and 73 deletions

View File

@@ -89,6 +89,7 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerRegisterReq,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -124,6 +125,8 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
dataclass_to_string_truncated,
get_bool_env_var,
get_origin_rid,
get_workerids_from_rids,
get_zmq_socket,
kill_process_tree,
)
@@ -171,6 +174,9 @@ class ReqState:
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
_global_tokenizer_worker_num = 1
class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
@@ -178,6 +184,7 @@ class TokenizerManager:
self,
server_args: ServerArgs,
port_args: PortArgs,
is_main: Optional[bool] = True,
):
# Parse args
self.server_args = server_args
@@ -191,6 +198,9 @@ class TokenizerManager:
)
self.crash_dump_folder = server_args.crash_dump_folder
self.is_main = is_main
self.worker_id = os.getpid()
# Read model args
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
@@ -255,13 +265,41 @@ class TokenizerManager:
)
# Init inter-process communication
context = zmq.asyncio.Context(2)
context = zmq.asyncio.Context(3)
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
global _global_tokenizer_worker_num
_global_tokenizer_worker_num = server_args.tokenizer_worker_num
if server_args.tokenizer_worker_num > 1:
self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
if self.is_main:
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
self.receive_from_worker = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
)
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._task = asyncio.run_coroutine_threadsafe(
self.router_worker_obj(), self._loop
)
# Start handle_loop simultaneously
self._handle_task = asyncio.run_coroutine_threadsafe(
print_exception_wrapper(self.handle_loop), self._loop
)
else:
# Non-main workers do not reach the scheduler directly: this socket pushes to the main process's receive_from_worker socket
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
)
else:
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
# Request states
self.no_create_loop = False
@@ -315,26 +353,27 @@ class TokenizerManager:
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
is_create_store = (
self.server_args.node_rank == 0
and self.server_args.disaggregation_transfer_backend == "ascend"
)
if is_create_store:
try:
from mf_adapter import create_config_store
if self.is_main:
kv_bootstrap_server_class = get_kv_class(
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
is_create_store = (
self.server_args.node_rank == 0
and self.server_args.disaggregation_transfer_backend == "ascend"
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
# For load balancing
self.current_load = 0
@@ -467,6 +506,14 @@ class TokenizerManager:
]
)
def _run_loop(self):
self._loop.run_forever()
async def router_worker_obj(self):
    """Relay worker traffic to the scheduler, forever.

    Each object pulled from ``self.receive_from_worker`` is pushed unchanged
    to ``self.send_to_scheduler``. The coroutine never returns on its own; it
    ends only when one of the socket calls raises or the task is cancelled.
    """
    while True:
        forwarded = await self.receive_from_worker.recv_pyobj()
        await self.send_to_scheduler.send_pyobj(forwarded)
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -479,6 +526,15 @@ class TokenizerManager:
async with self._is_updating_cond:
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
if self.server_args.tokenizer_worker_num > 1:
# Modify rid, add worker_id
if isinstance(obj.rid, list):
# If it's an array, add worker_id prefix to each element
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
else:
# If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}"
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
@@ -1505,12 +1561,378 @@ class TokenizerManager:
async def handle_loop(self):
    """The event loop that handles requests.

    Waits on ``self.recv_from_detokenizer`` forever. Every received object is
    handled exactly once:

    * main process with ``tokenizer_worker_num > 1``: routed to the owning
      worker(s) through ``_distribute_result_to_workers``;
    * otherwise (single-worker mode, or a non-main worker): handed to
      ``self._result_dispatcher`` in this process.

    ``self.last_receive_tstamp`` is refreshed after each object is handled.
    """
    while True:
        recv_obj = await self.recv_from_detokenizer.recv_pyobj()
        # NOTE: no unconditional `_result_dispatcher` call here -- doing so
        # in addition to the branch below would process each result twice.
        if self.server_args.tokenizer_worker_num > 1 and self.is_main:
            # In multi-worker mode, distribute results to corresponding workers
            await self._distribute_result_to_workers(recv_obj)
        else:
            # In single worker mode, process directly
            self._result_dispatcher(recv_obj)
        self.last_receive_tstamp = time.time()
def init_tokenizer_mapping(self, recv_obj: MultiTokenizerRegisterReq):
    """Create the PUSH socket(s) for the worker(s) named in a register request.

    Worker ids come from ``get_workerids_from_rids(recv_obj.rids)``. For each
    id not yet present in ``self.tokenizer_mapping`` a socket is opened on
    ``recv_obj.ipc_name`` and stored under the id converted to ``int``; ids
    already present are only logged.

    Reads ``self.tokenizer_mapping`` and ``self._zmq_context``, which this
    method does not create itself.

    Raises:
        RuntimeError: if ``recv_obj.rids`` is not a list.
    """
    if not isinstance(recv_obj.rids, list):
        raise RuntimeError(f"tokenizer_worker_num > 1, recv_obj.rids must be list")

    for worker_id in get_workerids_from_rids(recv_obj.rids):
        ipc_name = recv_obj.ipc_name
        mapping_key = int(worker_id)
        if mapping_key in self.tokenizer_mapping:
            logger.info(
                f"ZMQ socket for worker {worker_id} already exists, skipping creation"
            )
            continue
        # Not bound locally (last argument False): the socket connects to the
        # endpoint announced by the worker.
        push_socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
        self.tokenizer_mapping[mapping_key] = push_socket
        logger.info(
            f"Main Tokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}"
        )
async def _distribute_result_to_workers(self, recv_obj):
    """Distribute result to corresponding workers based on rid.

    ``get_workerids_from_rids(recv_obj.rids)`` yields one worker id per rid,
    matched to the rids by position. For each (position, worker id) pair:

    * a ``MultiTokenizerRegisterReq`` from a worker that has no socket yet
      creates the socket via ``init_tokenizer_mapping`` and is then forwarded
      to that worker; a register request from a known worker is ignored;
    * any other object for an unknown worker is logged and skipped;
    * a batch output (``BatchTokenIDOut``, ``BatchEmbeddingOut``,
      ``BatchStrOut``, ``BatchMultimodalOut``) is cut down to a
      single-request object of the same type holding only that position;
    * every other object is forwarded unchanged.

    When no worker id can be derived, the object is handled locally through
    ``self._result_dispatcher``.

    Raises:
        RuntimeError: if sending a split batch output fails with
            ``zmq.ZMQError``.
    """

    def entry_at(values, idx):
        # One-element list with values[idx], or None when the field is unset
        # or has no entry at idx. Applied uniformly so that a short list can
        # never raise IndexError inside the receive loop.
        if values is None or len(values) <= idx:
            return None
        return [values[idx]]

    worker_ids = get_workerids_from_rids(recv_obj.rids)
    if len(worker_ids) == 0:
        self._result_dispatcher(recv_obj)
        return

    if not hasattr(self, "tokenizer_mapping"):
        self.tokenizer_mapping = {}
    # Create ZMQ context if needed
    if not hasattr(self, "_zmq_context"):
        self._zmq_context = zmq.Context()

    # Fields shared by BatchTokenIDOut and BatchStrOut, in constructor order.
    trailing_fields = (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "output_token_logprobs_val",
        "output_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "output_top_logprobs_val",
        "output_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
        "output_token_ids_logprobs_val",
        "output_token_ids_logprobs_idx",
        "output_hidden_states",
    )
    # (output type, per-request fields) -- the field order is the positional
    # argument order of the type's constructor after ``rids``; the entry
    # order is the order in which the types are tested.
    split_specs = (
        (
            BatchTokenIDOut,
            (
                "finished_reasons",
                "decoded_texts",
                "decode_ids",
                "read_offsets",
                "output_ids",
                "skip_special_tokens",
                "spaces_between_special_tokens",
                "no_stop_trim",
                "prompt_tokens",
                "completion_tokens",
                "cached_tokens",
                "spec_verify_ct",
            )
            + trailing_fields,
        ),
        (
            BatchEmbeddingOut,
            ("finished_reasons", "embeddings", "prompt_tokens", "cached_tokens"),
        ),
        (
            BatchStrOut,
            (
                "finished_reasons",
                "output_strs",
                "output_ids",
                "prompt_tokens",
                "completion_tokens",
                "cached_tokens",
                "spec_verify_ct",
            )
            + trailing_fields,
        ),
        (
            BatchMultimodalOut,
            (
                "finished_reasons",
                "outputs",
                "prompt_tokens",
                "completion_tokens",
                "cached_tokens",
            ),
        ),
    )
    batch_types = (
        BatchStrOut,
        BatchEmbeddingOut,
        BatchTokenIDOut,
        BatchMultimodalOut,
    )
    is_register = isinstance(recv_obj, MultiTokenizerRegisterReq)

    # Distribute result to each worker
    for i, worker_id in enumerate(worker_ids):
        if worker_id not in self.tokenizer_mapping:
            if not is_register:
                logger.error(
                    f"Worker {worker_id} not registered and not found in tokenizer mapping . "
                    "Please ensure the worker is registered correctly."
                )
                continue
            self.init_tokenizer_mapping(recv_obj)
        elif is_register:
            # Already registered: nothing to create, nothing to forward.
            continue

        if not isinstance(recv_obj, batch_types):
            # Send to worker
            self.tokenizer_mapping[worker_id].send_pyobj(recv_obj)
            continue

        # Build the single-request copy for position i.
        for out_type, field_names in split_specs:
            if isinstance(recv_obj, out_type):
                new_recv_obj = out_type(
                    [recv_obj.rids[i]],
                    *(entry_at(getattr(recv_obj, name), i) for name in field_names),
                )
                break

        try:
            self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
        except zmq.ZMQError as e:
            raise RuntimeError(
                f"Failed to send result to worker {worker_id}: {e}"
            ) from e
def register_to_main_tokenizer_manager(self):
    """Register this worker to the main TokenizerManager.

    Sends a ``MultiTokenizerRegisterReq`` carrying this worker's id (as the
    rid ``"<worker_id>_registertokenizer"``) and its ``tokenizer_ipc_name``
    through ``self.send_to_scheduler``, then blocks the calling thread for
    five seconds.
    """
    registration = MultiTokenizerRegisterReq()
    registration.rids = [f"{self.worker_id}_registertokenizer"]
    registration.ipc_name = self.tokenizer_ipc_name
    self.send_to_scheduler.send_pyobj(registration)
    # NOTE(review): presumably gives the main process time to open its socket
    # back to this worker before requests start flowing -- confirm.
    time.sleep(5)
def _handle_batch_output(
self,
recv_obj: Union[
@@ -1524,10 +1946,12 @@ class TokenizerManager:
f"Received output for {rid=} but the state was deleted in TokenizerManager."
)
continue
originRid = rid
if self.server_args.tokenizer_worker_num > 1:
originRid = get_origin_rid(rid)
# Build meta_info and return value
meta_info = {
"id": rid,
"id": originRid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
}
@@ -1828,6 +2252,9 @@ class TokenizerManager:
if is_health_check_generate_req(recv_obj):
return
state = self.rid_to_state[recv_obj.rid]
rid = recv_obj.rid
if self.server_args.tokenizer_worker_num > 1:
rid = get_origin_rid(rid)
state.finished = True
if recv_obj.finished_reason:
out = {
@@ -1840,7 +2267,7 @@ class TokenizerManager:
out = {
"text": "",
"meta_info": {
"id": recv_obj.rid,
"id": rid,
"finish_reason": {
"type": "abort",
"message": "Abort before prefill",
@@ -2029,6 +2456,7 @@ class _Communicator(Generic[T]):
self._ready_queue: Deque[asyncio.Future] = deque()
async def __call__(self, obj):
global _global_tokenizer_worker_num
ready_event = asyncio.Event()
if self._result_event is not None or len(self._ready_queue) > 0:
self._ready_queue.append(ready_event)
@@ -2037,6 +2465,14 @@ class _Communicator(Generic[T]):
assert self._result_values is None
if obj:
if _global_tokenizer_worker_num > 1:
if obj.rids is None:
obj.rids = f"{os.getpid()}_{uuid.uuid4().hex}_Communicator"
else:
if isinstance(obj.rids, str):
obj.rids = f"{os.getpid()}_{obj.rids}"
elif isinstance(obj.rids, list):
obj.rids = [f"{os.getpid()}_{rid}" for rid in obj.rids]
self._sender.send_pyobj(obj)
self._result_event = asyncio.Event()
@@ -2051,6 +2487,19 @@ class _Communicator(Generic[T]):
return result_values
def handle_recv(self, recv_obj: T):
global _global_tokenizer_worker_num
if _global_tokenizer_worker_num > 1:
# If rids is a string and not empty, remove the prefix
if (
hasattr(recv_obj, "rids")
and isinstance(recv_obj.rids, str)
and recv_obj.rids
):
recv_obj.rids = get_origin_rid(recv_obj.rids)
# If rids is a list, remove prefix from each element
elif hasattr(recv_obj, "rids") and isinstance(recv_obj.rids, list):
recv_obj.rids = [get_origin_rid(rid) for rid in recv_obj.rids]
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
self._result_event.set()