|
|
|
|
@@ -1,480 +0,0 @@
|
|
|
|
|
# Copyright 2023-2024 SGLang Team
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
"""TokenizerManager is a process that tokenizes the text."""
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import signal
|
|
|
|
|
import sys
|
|
|
|
|
import threading
|
|
|
|
|
import uuid
|
|
|
|
|
from typing import Awaitable, Generic, List, Optional, Tuple, TypeVar, Union
|
|
|
|
|
|
|
|
|
|
import fastapi
|
|
|
|
|
import uvloop
|
|
|
|
|
import zmq
|
|
|
|
|
import zmq.asyncio
|
|
|
|
|
from fastapi import BackgroundTasks
|
|
|
|
|
|
|
|
|
|
from sglang.srt.aio_rwlock import RWLock
|
|
|
|
|
from sglang.srt.managers.generation_manager import GenerationManager
|
|
|
|
|
from sglang.srt.managers.io_struct import (
|
|
|
|
|
BatchEmbeddingOut,
|
|
|
|
|
BatchStrOut,
|
|
|
|
|
BatchTokenIDOut,
|
|
|
|
|
CloseSessionReqInput,
|
|
|
|
|
ConfigureLoggingReq,
|
|
|
|
|
EmbeddingReqInput,
|
|
|
|
|
FlushCacheReq,
|
|
|
|
|
GenerateReqInput,
|
|
|
|
|
GetWeightsByNameReqInput,
|
|
|
|
|
GetWeightsByNameReqOutput,
|
|
|
|
|
InitWeightsUpdateGroupReqInput,
|
|
|
|
|
InitWeightsUpdateGroupReqOutput,
|
|
|
|
|
OpenSessionReqInput,
|
|
|
|
|
OpenSessionReqOutput,
|
|
|
|
|
ProfileReq,
|
|
|
|
|
ReleaseMemoryOccupationReqInput,
|
|
|
|
|
ReleaseMemoryOccupationReqOutput,
|
|
|
|
|
ResumeMemoryOccupationReqInput,
|
|
|
|
|
ResumeMemoryOccupationReqOutput,
|
|
|
|
|
UpdateWeightFromDiskReqInput,
|
|
|
|
|
UpdateWeightFromDiskReqOutput,
|
|
|
|
|
UpdateWeightsFromDistributedReqInput,
|
|
|
|
|
UpdateWeightsFromDistributedReqOutput,
|
|
|
|
|
UpdateWeightsFromTensorReqInput,
|
|
|
|
|
UpdateWeightsFromTensorReqOutput,
|
|
|
|
|
)
|
|
|
|
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
|
|
|
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
|
|
|
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
|
|
|
|
|
|
|
|
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text.

    It sits between the HTTP layer and the scheduler process: generation /
    embedding requests are pushed to the scheduler over ZMQ, and finished
    batches arrive back from the detokenizer on a PULL socket, where they are
    routed to awaiting coroutines by a type-based dispatcher.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication: one PULL socket receiving results
        # from the detokenizer, one PUSH socket feeding the scheduler.
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name

        # Store states
        # no_create_loop flips to True once the background handle/watchdog
        # tasks have been spawned (see auto_create_handle_loop).
        self.no_create_loop = False

        # The event to notify the weight sync is finished.
        # Readers (in-flight generation) share the lock; weight updates take
        # the writer side so they never overlap with requests.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        # Keep strong references to background tasks so they are not GC'd.
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Delegates tokenization / request bookkeeping; forwards each request
        # object straight to the scheduler socket.
        self._generation_manager = GenerationManager(
            server_args=server_args,
            on_request=self.send_to_scheduler.send_pyobj,
        )

        # Others
        self.gracefully_exit = False
        # One _Communicator per RPC type; each fans out to dp_size replicas.
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        # Routes each received object to its handler by its (output) type.
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
                    self._generation_manager.handle_batch_output,
                ),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
            ]
        )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Stream results for one generation/embedding request.

        Holds the reader side of the model-update lock so weight syncs cannot
        run while this request is in flight.
        """
        self.auto_create_handle_loop()
        async with self.model_update_lock.reader_lock:
            async for value in self._generation_manager.generate_request(obj, request):
                yield value

    def flush_cache(self):
        """Ask the scheduler to flush its KV cache (fire-and-forget)."""
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        """Abort the request identified by `rid`."""
        self._generation_manager.abort_request(rid)

    def start_profile(self):
        """Tell the scheduler to start profiling (fire-and-forget)."""
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        """Tell the scheduler to stop profiling (fire-and-forget)."""
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Reload model weights from disk.

        Returns (success, message) aggregated across dp replicas.
        """
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        # NOTE(review): `if True:` looks like a leftover from a removed
        # async/sync branch — the writer lock is always taken.
        if True:
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        """Send the update request and await replies from all dp replicas."""
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                # On success, the manager starts serving the new model path.
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message
        else:  # self.server_args.dp_size > 1
            # Collect one reply per replica; the future resolves to the list.
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            return all_success, all_message

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Initialize the distributed group used for online weight updates."""
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Update weights broadcast over the distributed group."""
        self.auto_create_handle_loop()
        # NOTE(review): assertion message is missing the "1"
        # ("dp_size must be 1 ..."); cannot change runtime strings here.
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be for update weights from distributed"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Update weights from in-memory tensors."""
        self.auto_create_handle_loop()
        # NOTE(review): same truncated assertion message as above.
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be for update weights from distributed"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        """Fetch a named parameter; a single value for dp_size == 1,
        otherwise one value per replica."""
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Ask all replicas to release GPU memory occupation."""
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Ask all replicas to resume GPU memory occupation."""
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Open a scheduler session; returns the session id, or None if the
        requested id already exists or the scheduler rejected it."""
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)

        # The future is resolved by _handle_open_session_req_output.
        self.session_futures[obj.session_id] = asyncio.Future()
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Close a scheduler session (no reply is awaited beyond the send)."""
        await self.send_to_scheduler.send_pyobj(obj)

    def configure_logging(self, obj: ConfigureLoggingReq):
        """Forward a logging-configuration request to the generation manager."""
        self._generation_manager.configure_logging(obj)

    def create_abort_task(self, obj: GenerateReqInput):
        """Return BackgroundTasks that abort `obj` after the response ends
        (i.e. when the client disconnects)."""

        # Abort the request if the client is disconnected.
        # NOTE(review): this inner function shadows the abort_request method.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def auto_create_handle_loop(self):
        """Lazily start the result-handling loop and SIGTERM watchdog once."""
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        # We cannot add signal handler when the tokenizer manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        """Wait for the SIGTERM flag, drain in-flight requests, then exit."""
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self._generation_manager.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)

    def _handle_open_session_req_output(self, recv_obj):
        # Resolve the waiter in open_session with the id on success, None on
        # failure.
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
            # set future if the all results are recevied
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)

    @property
    def is_generation(self):
        # Whether the underlying model is a generation (vs. embedding) model.
        return self._generation_manager.model_config.is_generation

    @property
    def tokenizer(self):
        # The tokenizer owned by the generation manager.
        return self._generation_manager.tokenizer

    @property
    def image_token_id(self):
        # Token id used for image placeholders, per the model config.
        return self._generation_manager.model_config.image_token_id

    def configure_max_req_input_len(self, max_req_input_len):
        """Set the maximum allowed request input length on the converter."""
        self._generation_manager.generation_converter.max_req_input_len = (
            max_req_input_len
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def print_exception_wrapper(func):
    """Await ``func()`` and make any escaped exception fatal and visible.

    asyncio tasks can silently drop exceptions; this wrapper logs the full
    traceback and then tears down the whole process tree so the failure is
    never lost.
    """
    try:
        await func()
        return
    except Exception:
        tb = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {tb}")

    # Reached only on failure: nuke the process tree and exit non-zero.
    kill_process_tree(os.getpid(), include_parent=True)
    sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SignalHandler:
    """SIGTERM handler that flips the manager's graceful-exit flag.

    The TokenizerManager's sigterm watchdog polls ``gracefully_exit`` and
    drains outstanding requests before shutting the process down.
    """

    def __init__(self, tokenizer_manager):
        # Keep a reference so the handler can reach the flag later.
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        """Log the signal and request a graceful shutdown."""
        message = f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        logger.warning(message)
        self.tokenizer_manager.gracefully_exit = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
T = TypeVar("T")


class _Communicator(Generic[T]):
    """Fan-out request/reply helper over a one-way sender.

    ``__call__`` pushes one request object through ``sender`` and then waits
    until exactly ``fan_out`` replies have been delivered via
    ``handle_recv`` (one per data-parallel replica), returning them as a
    list in arrival order. Only a single request may be in flight at a time.
    """

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        # Both are non-None only while a request is in flight.
        self._result_future: Optional[asyncio.Future] = None
        self._result_values: Optional[List[T]] = None

    async def __call__(self, obj):
        """Send *obj*, then block until every replica has replied."""
        self._sender.send_pyobj(obj)
        self._result_future = asyncio.Future()
        self._result_values = []
        await self._result_future
        collected = self._result_values
        # Clear the in-flight state so the next call starts fresh.
        self._result_future = self._result_values = None
        return collected

    def handle_recv(self, recv_obj: T):
        """Record one reply; wake the waiter once all replies arrived."""
        values = self._result_values
        values.append(recv_obj)
        if len(values) == self._fan_out:
            self._result_future.set_result(None)
|