Support Multi Process Tokenizer Manager (#6555)
Signed-off-by: ybyang <ybyang7@iflytek.com> Signed-off-by: huanglong <huanglong@linux.alibaba.com> Co-authored-by: lw9527 <952799980@qq.com> Co-authored-by: huanglong <huanglong@linux.alibaba.com> Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
This commit is contained in:
@@ -89,6 +89,7 @@ from sglang.srt.managers.io_struct import (
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
LoRAUpdateResult,
|
||||
MultiTokenizerRegisterReq,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -124,6 +125,8 @@ from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
dataclass_to_string_truncated,
|
||||
get_bool_env_var,
|
||||
get_origin_rid,
|
||||
get_workerids_from_rids,
|
||||
get_zmq_socket,
|
||||
kill_process_tree,
|
||||
)
|
||||
@@ -171,6 +174,9 @@ class ReqState:
|
||||
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
_global_tokenizer_worker_num = 1
|
||||
|
||||
|
||||
class TokenizerManager:
|
||||
"""TokenizerManager is a process that tokenizes the text."""
|
||||
|
||||
@@ -178,6 +184,7 @@ class TokenizerManager:
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
is_main: Optional[bool] = True,
|
||||
):
|
||||
# Parse args
|
||||
self.server_args = server_args
|
||||
@@ -191,6 +198,9 @@ class TokenizerManager:
|
||||
)
|
||||
self.crash_dump_folder = server_args.crash_dump_folder
|
||||
|
||||
self.is_main = is_main
|
||||
self.worker_id = os.getpid()
|
||||
|
||||
# Read model args
|
||||
self.model_path = server_args.model_path
|
||||
self.served_model_name = server_args.served_model_name
|
||||
@@ -255,13 +265,41 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.asyncio.Context(2)
|
||||
context = zmq.asyncio.Context(3)
|
||||
self.recv_from_detokenizer = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
||||
)
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||
)
|
||||
global _global_tokenizer_worker_num
|
||||
_global_tokenizer_worker_num = server_args.tokenizer_worker_num
|
||||
if server_args.tokenizer_worker_num > 1:
|
||||
self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
||||
if self.is_main:
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||
)
|
||||
self.receive_from_worker = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
|
||||
)
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._task = asyncio.run_coroutine_threadsafe(
|
||||
self.router_worker_obj(), self._loop
|
||||
)
|
||||
# Start handle_loop simultaneously
|
||||
self._handle_task = asyncio.run_coroutine_threadsafe(
|
||||
print_exception_wrapper(self.handle_loop), self._loop
|
||||
)
|
||||
|
||||
else:
|
||||
# actual send to main receiver_from_worker
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
|
||||
)
|
||||
else:
|
||||
self.send_to_scheduler = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||
)
|
||||
|
||||
# Request states
|
||||
self.no_create_loop = False
|
||||
@@ -315,26 +353,27 @@ class TokenizerManager:
|
||||
# Start kv boostrap server on prefill
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# only start bootstrap server on prefill tm
|
||||
kv_bootstrap_server_class = get_kv_class(
|
||||
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
)
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
is_create_store = (
|
||||
self.server_args.node_rank == 0
|
||||
and self.server_args.disaggregation_transfer_backend == "ascend"
|
||||
)
|
||||
if is_create_store:
|
||||
try:
|
||||
from mf_adapter import create_config_store
|
||||
if self.is_main:
|
||||
kv_bootstrap_server_class = get_kv_class(
|
||||
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
)
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
is_create_store = (
|
||||
self.server_args.node_rank == 0
|
||||
and self.server_args.disaggregation_transfer_backend == "ascend"
|
||||
)
|
||||
if is_create_store:
|
||||
try:
|
||||
from mf_adapter import create_config_store
|
||||
|
||||
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
||||
create_config_store(ascend_url)
|
||||
except Exception as e:
|
||||
error_message = f"Failed create mf store, invalid ascend_url."
|
||||
error_message += f" With exception {e}"
|
||||
raise error_message
|
||||
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
||||
create_config_store(ascend_url)
|
||||
except Exception as e:
|
||||
error_message = f"Failed create mf store, invalid ascend_url."
|
||||
error_message += f" With exception {e}"
|
||||
raise error_message
|
||||
|
||||
# For load balancing
|
||||
self.current_load = 0
|
||||
@@ -467,6 +506,14 @@ class TokenizerManager:
|
||||
]
|
||||
)
|
||||
|
||||
def _run_loop(self):
|
||||
self._loop.run_forever()
|
||||
|
||||
async def router_worker_obj(self):
|
||||
while True:
|
||||
recv_obj = await self.receive_from_worker.recv_pyobj()
|
||||
await self.send_to_scheduler.send_pyobj(recv_obj)
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -479,6 +526,15 @@ class TokenizerManager:
|
||||
async with self._is_updating_cond:
|
||||
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
|
||||
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
# Modify rid, add worker_id
|
||||
if isinstance(obj.rid, list):
|
||||
# If it's an array, add worker_id prefix to each element
|
||||
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
|
||||
else:
|
||||
# If it's a single value, add worker_id prefix
|
||||
obj.rid = f"{self.worker_id}_{obj.rid}"
|
||||
|
||||
if self.log_requests:
|
||||
max_length, skip_names, _ = self.log_request_metadata
|
||||
logger.info(
|
||||
@@ -1505,12 +1561,378 @@ class TokenizerManager:
|
||||
|
||||
async def handle_loop(self):
|
||||
"""The event loop that handles requests"""
|
||||
|
||||
while True:
|
||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||
self._result_dispatcher(recv_obj)
|
||||
# In multi-worker mode, distribute results to corresponding workers
|
||||
if self.server_args.tokenizer_worker_num > 1 and self.is_main:
|
||||
await self._distribute_result_to_workers(recv_obj)
|
||||
else:
|
||||
# In single worker mode, process directly
|
||||
self._result_dispatcher(recv_obj)
|
||||
|
||||
self.last_receive_tstamp = time.time()
|
||||
|
||||
def init_tokenizer_mapping(self, recv_obj: MultiTokenizerRegisterReq):
|
||||
"""init tokenizer mapping from register request"""
|
||||
if isinstance(recv_obj.rids, list):
|
||||
worker_ids = get_workerids_from_rids(recv_obj.rids)
|
||||
else:
|
||||
raise RuntimeError(f"tokenizer_worker_num > 1, recv_obj.rids must be list")
|
||||
|
||||
for worker_id in worker_ids:
|
||||
ipc_name = recv_obj.ipc_name
|
||||
worker_id_int = int(worker_id)
|
||||
|
||||
if worker_id_int not in self.tokenizer_mapping:
|
||||
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
|
||||
self.tokenizer_mapping[worker_id_int] = socket
|
||||
logger.info(
|
||||
f"Main Tokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"ZMQ socket for worker {worker_id} already exists, skipping creation"
|
||||
)
|
||||
|
||||
async def _distribute_result_to_workers(self, recv_obj):
|
||||
"""Distribute result to corresponding workers based on rid"""
|
||||
|
||||
worker_ids = get_workerids_from_rids(recv_obj.rids)
|
||||
if len(worker_ids) == 0:
|
||||
self._result_dispatcher(recv_obj)
|
||||
return
|
||||
|
||||
if not hasattr(self, "tokenizer_mapping"):
|
||||
self.tokenizer_mapping = {}
|
||||
|
||||
# Create ZMQ context if needed
|
||||
if not hasattr(self, "_zmq_context"):
|
||||
self._zmq_context = zmq.Context()
|
||||
|
||||
# Distribute result to each worker
|
||||
for i, worker_id in enumerate(worker_ids):
|
||||
if worker_id not in self.tokenizer_mapping:
|
||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||
self.init_tokenizer_mapping(recv_obj)
|
||||
else:
|
||||
logger.error(
|
||||
f"Worker {worker_id} not registered and not found in tokenizer mapping . "
|
||||
"Please ensure the worker is registered correctly."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
if isinstance(recv_obj, MultiTokenizerRegisterReq):
|
||||
continue
|
||||
|
||||
if not isinstance(
|
||||
recv_obj,
|
||||
(
|
||||
BatchStrOut,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
BatchMultimodalOut,
|
||||
),
|
||||
):
|
||||
# Send to worker
|
||||
self.tokenizer_mapping[worker_id].send_pyobj(recv_obj)
|
||||
else:
|
||||
if isinstance(recv_obj, BatchTokenIDOut):
|
||||
new_recv_obj = BatchTokenIDOut(
|
||||
[recv_obj.rids[i]],
|
||||
(
|
||||
[recv_obj.finished_reasons[i]]
|
||||
if len(recv_obj.finished_reasons) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.decoded_texts[i]]
|
||||
if len(recv_obj.decoded_texts) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.decode_ids[i]]
|
||||
if len(recv_obj.decode_ids) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.read_offsets[i]]
|
||||
if len(recv_obj.read_offsets) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_ids[i]]
|
||||
if recv_obj.output_ids and len(recv_obj.output_ids) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.skip_special_tokens[i]]
|
||||
if len(recv_obj.skip_special_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.spaces_between_special_tokens[i]]
|
||||
if len(recv_obj.spaces_between_special_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.no_stop_trim[i]]
|
||||
if len(recv_obj.no_stop_trim) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.prompt_tokens[i]]
|
||||
if len(recv_obj.prompt_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.completion_tokens[i]]
|
||||
if len(recv_obj.completion_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.cached_tokens[i]]
|
||||
if len(recv_obj.cached_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.spec_verify_ct[i]]
|
||||
if len(recv_obj.spec_verify_ct) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_token_logprobs_val[i]]
|
||||
if recv_obj.input_token_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_token_logprobs_idx[i]]
|
||||
if recv_obj.input_token_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_token_logprobs_val[i]]
|
||||
if recv_obj.output_token_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_token_logprobs_idx[i]]
|
||||
if recv_obj.output_token_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_top_logprobs_val[i]]
|
||||
if recv_obj.input_top_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_top_logprobs_idx[i]]
|
||||
if recv_obj.input_top_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_top_logprobs_val[i]]
|
||||
if recv_obj.output_top_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_top_logprobs_idx[i]]
|
||||
if recv_obj.output_top_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_token_ids_logprobs_val[i]]
|
||||
if recv_obj.input_token_ids_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_token_ids_logprobs_idx[i]]
|
||||
if recv_obj.input_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_token_ids_logprobs_val[i]]
|
||||
if recv_obj.output_token_ids_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_token_ids_logprobs_idx[i]]
|
||||
if recv_obj.output_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_hidden_states[i]]
|
||||
if recv_obj.output_hidden_states
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif isinstance(recv_obj, BatchEmbeddingOut):
|
||||
new_recv_obj = BatchEmbeddingOut(
|
||||
[recv_obj.rids[i]],
|
||||
(
|
||||
[recv_obj.finished_reasons[i]]
|
||||
if len(recv_obj.finished_reasons) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.embeddings[i]]
|
||||
if len(recv_obj.embeddings) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.prompt_tokens[i]]
|
||||
if len(recv_obj.prompt_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.cached_tokens[i]]
|
||||
if len(recv_obj.cached_tokens) > i
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif isinstance(recv_obj, BatchStrOut):
|
||||
new_recv_obj = BatchStrOut(
|
||||
[recv_obj.rids[i]],
|
||||
(
|
||||
[recv_obj.finished_reasons[i]]
|
||||
if len(recv_obj.finished_reasons) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_strs[i]]
|
||||
if len(recv_obj.output_strs) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_ids[i]]
|
||||
if recv_obj.output_ids and len(recv_obj.output_ids) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.prompt_tokens[i]]
|
||||
if len(recv_obj.prompt_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.completion_tokens[i]]
|
||||
if len(recv_obj.completion_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.cached_tokens[i]]
|
||||
if len(recv_obj.cached_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.spec_verify_ct[i]]
|
||||
if len(recv_obj.spec_verify_ct) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_token_logprobs_val[i]]
|
||||
if recv_obj.input_token_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_token_logprobs_idx[i]]
|
||||
if recv_obj.input_token_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_token_logprobs_val[i]]
|
||||
if recv_obj.output_token_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_token_logprobs_idx[i]]
|
||||
if recv_obj.output_token_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_top_logprobs_val[i]]
|
||||
if recv_obj.input_top_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_top_logprobs_idx[i]]
|
||||
if recv_obj.input_top_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_top_logprobs_val[i]]
|
||||
if recv_obj.output_top_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_top_logprobs_idx[i]]
|
||||
if recv_obj.output_top_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_token_ids_logprobs_val[i]]
|
||||
if recv_obj.input_token_ids_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.input_token_ids_logprobs_idx[i]]
|
||||
if recv_obj.input_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_token_ids_logprobs_val[i]]
|
||||
if recv_obj.output_token_ids_logprobs_val
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_token_ids_logprobs_idx[i]]
|
||||
if recv_obj.output_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.output_hidden_states[i]]
|
||||
if recv_obj.output_hidden_states
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif isinstance(recv_obj, BatchMultimodalOut):
|
||||
new_recv_obj = BatchMultimodalOut(
|
||||
[recv_obj.rids[i]],
|
||||
(
|
||||
[recv_obj.finished_reasons[i]]
|
||||
if len(recv_obj.finished_reasons) > i
|
||||
else None
|
||||
),
|
||||
([recv_obj.outputs[i]] if len(recv_obj.outputs) > i else None),
|
||||
(
|
||||
[recv_obj.prompt_tokens[i]]
|
||||
if len(recv_obj.prompt_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.completion_tokens[i]]
|
||||
if len(recv_obj.completion_tokens) > i
|
||||
else None
|
||||
),
|
||||
(
|
||||
[recv_obj.cached_tokens[i]]
|
||||
if len(recv_obj.cached_tokens) > i
|
||||
else None
|
||||
),
|
||||
)
|
||||
try:
|
||||
self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
|
||||
except zmq.ZMQError as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to send result to worker {worker_id}: {e}"
|
||||
) from e
|
||||
|
||||
def register_to_main_tokenizer_manager(self):
|
||||
"""Register this worker to the main TokenizerManager"""
|
||||
req = MultiTokenizerRegisterReq()
|
||||
req.rids = [f"{self.worker_id}_registertokenizer"]
|
||||
req.ipc_name = self.tokenizer_ipc_name
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
time.sleep(5)
|
||||
|
||||
def _handle_batch_output(
|
||||
self,
|
||||
recv_obj: Union[
|
||||
@@ -1524,10 +1946,12 @@ class TokenizerManager:
|
||||
f"Received output for {rid=} but the state was deleted in TokenizerManager."
|
||||
)
|
||||
continue
|
||||
|
||||
originRid = rid
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
originRid = get_origin_rid(rid)
|
||||
# Build meta_info and return value
|
||||
meta_info = {
|
||||
"id": rid,
|
||||
"id": originRid,
|
||||
"finish_reason": recv_obj.finished_reasons[i],
|
||||
"prompt_tokens": recv_obj.prompt_tokens[i],
|
||||
}
|
||||
@@ -1828,6 +2252,9 @@ class TokenizerManager:
|
||||
if is_health_check_generate_req(recv_obj):
|
||||
return
|
||||
state = self.rid_to_state[recv_obj.rid]
|
||||
rid = recv_obj.rid
|
||||
if self.server_args.tokenizer_worker_num > 1:
|
||||
rid = get_origin_rid(rid)
|
||||
state.finished = True
|
||||
if recv_obj.finished_reason:
|
||||
out = {
|
||||
@@ -1840,7 +2267,7 @@ class TokenizerManager:
|
||||
out = {
|
||||
"text": "",
|
||||
"meta_info": {
|
||||
"id": recv_obj.rid,
|
||||
"id": rid,
|
||||
"finish_reason": {
|
||||
"type": "abort",
|
||||
"message": "Abort before prefill",
|
||||
@@ -2029,6 +2456,7 @@ class _Communicator(Generic[T]):
|
||||
self._ready_queue: Deque[asyncio.Future] = deque()
|
||||
|
||||
async def __call__(self, obj):
|
||||
global _global_tokenizer_worker_num
|
||||
ready_event = asyncio.Event()
|
||||
if self._result_event is not None or len(self._ready_queue) > 0:
|
||||
self._ready_queue.append(ready_event)
|
||||
@@ -2037,6 +2465,14 @@ class _Communicator(Generic[T]):
|
||||
assert self._result_values is None
|
||||
|
||||
if obj:
|
||||
if _global_tokenizer_worker_num > 1:
|
||||
if obj.rids is None:
|
||||
obj.rids = f"{os.getpid()}_{uuid.uuid4().hex}_Communicator"
|
||||
else:
|
||||
if isinstance(obj.rids, str):
|
||||
obj.rids = f"{os.getpid()}_{obj.rids}"
|
||||
elif isinstance(obj.rids, list):
|
||||
obj.rids = [f"{os.getpid()}_{rid}" for rid in obj.rids]
|
||||
self._sender.send_pyobj(obj)
|
||||
|
||||
self._result_event = asyncio.Event()
|
||||
@@ -2051,6 +2487,19 @@ class _Communicator(Generic[T]):
|
||||
return result_values
|
||||
|
||||
def handle_recv(self, recv_obj: T):
|
||||
global _global_tokenizer_worker_num
|
||||
if _global_tokenizer_worker_num > 1:
|
||||
# If rids is a string and not empty, remove the prefix
|
||||
if (
|
||||
hasattr(recv_obj, "rids")
|
||||
and isinstance(recv_obj.rids, str)
|
||||
and recv_obj.rids
|
||||
):
|
||||
recv_obj.rids = get_origin_rid(recv_obj.rids)
|
||||
# If rids is a list, remove prefix from each element
|
||||
elif hasattr(recv_obj, "rids") and isinstance(recv_obj.rids, list):
|
||||
recv_obj.rids = [get_origin_rid(rid) for rid in recv_obj.rids]
|
||||
|
||||
self._result_values.append(recv_obj)
|
||||
if len(self._result_values) == self._fan_out:
|
||||
self._result_event.set()
|
||||
|
||||
Reference in New Issue
Block a user