Support Multi Process Tokenizer Manager (#6555)
Signed-off-by: ybyang <ybyang7@iflytek.com> Signed-off-by: huanglong <huanglong@linux.alibaba.com> Co-authored-by: lw9527 <952799980@qq.com> Co-authored-by: huanglong <huanglong@linux.alibaba.com> Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
This commit is contained in:
@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchMultimodalOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
MultiTokenizerRegisterReq,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
get_workerids_from_rids,
|
||||
get_zmq_socket,
|
||||
kill_itself_when_parent_died,
|
||||
)
|
||||
@@ -81,7 +83,6 @@ class DetokenizerManager:
|
||||
self.send_to_tokenizer = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
||||
)
|
||||
|
||||
if server_args.skip_tokenizer_init:
|
||||
self.tokenizer = None
|
||||
else:
|
||||
@@ -94,21 +95,208 @@ class DetokenizerManager:
|
||||
|
||||
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
|
||||
self.is_dummy = server_args.load_format == "dummy"
|
||||
|
||||
self.tokenizer_worker_num = server_args.tokenizer_worker_num
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
||||
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
||||
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||
(MultiTokenizerRegisterReq, lambda x: None),
|
||||
]
|
||||
)
|
||||
|
||||
def event_loop(self):
    """Run the blocking receive -> dispatch -> forward loop.

    Each iteration receives one object from ``self.recv_from_scheduler``,
    passes it through ``self._request_dispatcher`` and forwards the result:

    * ``self.tokenizer_worker_num <= 1``: the result is pushed unchanged to
      ``self.send_to_tokenizer``.
    * otherwise: the result is split per request and every slice is pushed
      to the socket registered for the worker id derived from that
      request's rid (see ``_route_to_tokenizer_workers``).

    Raises:
        Exception: any error raised while receiving, dispatching or routing
            is logged and re-raised, which ends the loop.
    """
    while True:
        try:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            output = self._request_dispatcher(recv_obj)
            if self.tokenizer_worker_num <= 1:
                self.send_to_tokenizer.send_pyobj(output)
            else:
                self._route_to_tokenizer_workers(recv_obj, output)
        except Exception as e:
            logger.error(f"Error in detokenizer event loop: {e}")
            # Re-raise the active exception unchanged.
            raise


def _route_to_tokenizer_workers(self, recv_obj, output):
    """Send ``output`` to the tokenizer worker owning each rid of ``recv_obj``.

    Worker ids come from ``get_workerids_from_rids(recv_obj.rids)``; the
    position of a worker id in that list is used as the index into the
    batched fields of ``output``.

    Raises:
        RuntimeError: if ``recv_obj.rids`` is not a list.
    """
    if not isinstance(recv_obj.rids, list):
        raise RuntimeError("tokenizer_worker_num > 1, recv_obj.rids must be list")
    worker_ids = get_workerids_from_rids(recv_obj.rids)

    # Routing state is created on first use rather than in __init__.
    if not hasattr(self, "tokenizer_mapping"):
        self.tokenizer_mapping = {}
    if not hasattr(self, "_zmq_context"):
        self._zmq_context = zmq.Context()

    is_register_req = isinstance(recv_obj, MultiTokenizerRegisterReq)
    for i, worker_id in enumerate(worker_ids):
        if is_register_req:
            # A registration request only (possibly) creates the socket;
            # there is no decoded output to forward for it.
            if worker_id not in self.tokenizer_mapping:
                self.init_tokenizer_mapping(recv_obj, worker_id)
            continue

        if worker_id not in self.tokenizer_mapping:
            logger.error(
                f"Worker {worker_id} not registered and not found in tokenizer mapping . "
                "Please ensure the worker is registered correctly."
            )
            continue

        new_output = self._slice_output_for_request(output, i)
        try:
            self.tokenizer_mapping[worker_id].send_pyobj(new_output)
        except zmq.error.ZMQError as e:
            # Best effort: a failed send drops this slice but must not stop
            # delivery to the remaining workers. Logged above info level
            # because a response is lost.
            logger.warning(f"ZMQ error when sending to worker {worker_id}: {e}")


def _slice_output_for_request(self, output, i):
    """Return a single-request copy of the batched ``output`` at index ``i``.

    ``BatchEmbeddingOut``, ``BatchStrOut`` and ``BatchMultimodalOut`` are
    rebuilt with one-element lists; any other object is returned unchanged.
    """

    def pick(values):
        # One guard for every optional per-request field: missing, empty or
        # too-short lists become None instead of raising.
        if values is not None and len(values) > i:
            return [values[i]]
        return None

    if isinstance(output, BatchEmbeddingOut):
        return BatchEmbeddingOut(
            rids=[output.rids[i]],
            finished_reasons=[output.finished_reasons[i]],
            embeddings=[output.embeddings[i]],
            prompt_tokens=[output.prompt_tokens[i]],
            cached_tokens=[output.cached_tokens[i]],
        )

    if isinstance(output, BatchStrOut):
        optional_fields = (
            "finished_reasons",
            "output_strs",
            "output_ids",
            "prompt_tokens",
            "completion_tokens",
            "cached_tokens",
            "spec_verify_ct",
            "input_token_logprobs_val",
            "input_token_logprobs_idx",
            "output_token_logprobs_val",
            "output_token_logprobs_idx",
            "input_top_logprobs_val",
            "input_top_logprobs_idx",
            "output_top_logprobs_val",
            "output_top_logprobs_idx",
            "input_token_ids_logprobs_val",
            "input_token_ids_logprobs_idx",
            "output_token_ids_logprobs_val",
            "output_token_ids_logprobs_idx",
            "output_hidden_states",
        )
        return BatchStrOut(
            rids=[output.rids[i]],
            **{name: pick(getattr(output, name)) for name in optional_fields},
        )

    if isinstance(output, BatchMultimodalOut):
        return BatchMultimodalOut(
            rids=[output.rids[i]],
            finished_reasons=[output.finished_reasons[i]],
            prompt_tokens=[output.prompt_tokens[i]],
            completion_tokens=[output.completion_tokens[i]],
            cached_tokens=[output.cached_tokens[i]],
        )

    return output
|
||||
|
||||
def init_tokenizer_mapping(
    self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
):
    """Create and cache the push socket for a registering tokenizer worker.

    The socket is obtained from ``get_zmq_socket`` for ``recv_obj.ipc_name``
    and stored in ``self.tokenizer_mapping`` under ``int(worker_id)``. A
    worker that already has an entry is left untouched.
    """
    ipc_name = recv_obj.ipc_name
    mapping_key = int(worker_id)

    # Guard clause: never replace a socket that is already registered.
    if mapping_key in self.tokenizer_mapping:
        logger.info(
            f"ZMQ socket for worker {worker_id} already exists, skipping creation"
        )
        return

    self.tokenizer_mapping[mapping_key] = get_zmq_socket(
        self._zmq_context, zmq.PUSH, ipc_name, False
    )
    logger.info(
        f"Detokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}"
    )
|
||||
|
||||
def trim_matched_stop(
|
||||
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
||||
|
||||
Reference in New Issue
Block a user