Refactor to add TypeBasedDispatcher to simplify dispatching (#2958)
This commit is contained in:
@@ -97,7 +97,7 @@ from sglang.srt.utils import (
|
|||||||
set_random_seed,
|
set_random_seed,
|
||||||
suppress_other_loggers,
|
suppress_other_loggers,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -422,6 +422,34 @@ class Scheduler:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._dispatcher = TypeBasedDispatcher(
|
||||||
|
[
|
||||||
|
(TokenizedGenerateReqInput, self.handle_generate_request),
|
||||||
|
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
||||||
|
(FlushCacheReq, self.flush_cache_wrapped),
|
||||||
|
(AbortReq, self.abort_request),
|
||||||
|
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||||
|
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||||
|
(
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
self.update_weights_from_distributed,
|
||||||
|
),
|
||||||
|
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
||||||
|
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
||||||
|
(ProfileReq, self.profile),
|
||||||
|
(OpenSessionReqInput, self.open_session),
|
||||||
|
(CloseSessionReqInput, self.close_session),
|
||||||
|
(
|
||||||
|
ReleaseMemoryOccupationReqInput,
|
||||||
|
lambda _: self.release_memory_occupation(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
ResumeMemoryOccupationReqInput,
|
||||||
|
lambda _: self.resume_memory_occupation(),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def watchdog_thread(self):
|
def watchdog_thread(self):
|
||||||
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
|
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
|
||||||
self.watchdog_last_forward_ct = 0
|
self.watchdog_last_forward_ct = 0
|
||||||
@@ -563,57 +591,9 @@ class Scheduler:
|
|||||||
|
|
||||||
def process_input_requests(self, recv_reqs: List):
|
def process_input_requests(self, recv_reqs: List):
|
||||||
for recv_req in recv_reqs:
|
for recv_req in recv_reqs:
|
||||||
if isinstance(recv_req, TokenizedGenerateReqInput):
|
output = self._dispatcher(recv_req)
|
||||||
self.handle_generate_request(recv_req)
|
if output is not None:
|
||||||
elif isinstance(recv_req, TokenizedEmbeddingReqInput):
|
self.send_to_tokenizer.send_pyobj(output)
|
||||||
self.handle_embedding_request(recv_req)
|
|
||||||
elif isinstance(recv_req, FlushCacheReq):
|
|
||||||
self.flush_cache()
|
|
||||||
elif isinstance(recv_req, AbortReq):
|
|
||||||
self.abort_request(recv_req)
|
|
||||||
elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
|
|
||||||
success, message = self.update_weights_from_disk(recv_req)
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
|
||||||
UpdateWeightFromDiskReqOutput(success, message)
|
|
||||||
)
|
|
||||||
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
|
|
||||||
success, message = self.init_weights_update_group(recv_req)
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
|
||||||
InitWeightsUpdateGroupReqOutput(success, message)
|
|
||||||
)
|
|
||||||
elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
|
|
||||||
success, message = self.update_weights_from_distributed(recv_req)
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
|
||||||
UpdateWeightsFromDistributedReqOutput(success, message)
|
|
||||||
)
|
|
||||||
elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
|
|
||||||
success, message = self.update_weights_from_tensor(recv_req)
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
|
||||||
UpdateWeightsFromTensorReqOutput(success, message)
|
|
||||||
)
|
|
||||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
|
||||||
parameter = self.get_weights_by_name(recv_req)
|
|
||||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
|
||||||
elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
|
|
||||||
self.release_memory_occupation()
|
|
||||||
self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
|
|
||||||
elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
|
|
||||||
self.resume_memory_occupation()
|
|
||||||
self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
|
|
||||||
elif isinstance(recv_req, ProfileReq):
|
|
||||||
if recv_req == ProfileReq.START_PROFILE:
|
|
||||||
self.start_profile()
|
|
||||||
else:
|
|
||||||
self.stop_profile()
|
|
||||||
elif isinstance(recv_req, OpenSessionReqInput):
|
|
||||||
session_id, success = self.open_session(recv_req)
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
|
||||||
OpenSessionReqOutput(session_id=session_id, success=success)
|
|
||||||
)
|
|
||||||
elif isinstance(recv_req, CloseSessionReqInput):
|
|
||||||
self.close_session(recv_req)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid request: {recv_req}")
|
|
||||||
|
|
||||||
def handle_generate_request(
|
def handle_generate_request(
|
||||||
self,
|
self,
|
||||||
@@ -1545,6 +1525,9 @@ class Scheduler:
|
|||||||
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
|
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
|
||||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||||
|
|
||||||
|
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
|
||||||
|
self.flush_cache()
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
"""Flush the memory pool and cache."""
|
"""Flush the memory pool and cache."""
|
||||||
if len(self.waiting_queue) == 0 and (
|
if len(self.waiting_queue) == 0 and (
|
||||||
@@ -1597,12 +1580,12 @@ class Scheduler:
|
|||||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||||
else:
|
else:
|
||||||
logger.error(message)
|
logger.error(message)
|
||||||
return success, message
|
return UpdateWeightFromDiskReqOutput(success, message)
|
||||||
|
|
||||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||||
"""Initialize the online model parameter update group."""
|
"""Initialize the online model parameter update group."""
|
||||||
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||||
return success, message
|
return InitWeightsUpdateGroupReqOutput(success, message)
|
||||||
|
|
||||||
def update_weights_from_distributed(
|
def update_weights_from_distributed(
|
||||||
self,
|
self,
|
||||||
@@ -1615,7 +1598,7 @@ class Scheduler:
|
|||||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||||
else:
|
else:
|
||||||
logger.error(message)
|
logger.error(message)
|
||||||
return success, message
|
return UpdateWeightsFromDistributedReqOutput(success, message)
|
||||||
|
|
||||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||||
"""Update the online model parameter from tensors."""
|
"""Update the online model parameter from tensors."""
|
||||||
@@ -1626,11 +1609,11 @@ class Scheduler:
|
|||||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||||
else:
|
else:
|
||||||
logger.error(message)
|
logger.error(message)
|
||||||
return success, message
|
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||||
return parameter
|
return GetWeightsByNameReqOutput(parameter)
|
||||||
|
|
||||||
def release_memory_occupation(self):
|
def release_memory_occupation(self):
|
||||||
self.stashed_model_static_state = _export_static_state(
|
self.stashed_model_static_state = _export_static_state(
|
||||||
@@ -1638,6 +1621,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
self.memory_saver_adapter.pause()
|
self.memory_saver_adapter.pause()
|
||||||
self.flush_cache()
|
self.flush_cache()
|
||||||
|
return ReleaseMemoryOccupationReqOutput()
|
||||||
|
|
||||||
def resume_memory_occupation(self):
|
def resume_memory_occupation(self):
|
||||||
self.memory_saver_adapter.resume()
|
self.memory_saver_adapter.resume()
|
||||||
@@ -1645,6 +1629,13 @@ class Scheduler:
|
|||||||
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
||||||
)
|
)
|
||||||
del self.stashed_model_static_state
|
del self.stashed_model_static_state
|
||||||
|
return ResumeMemoryOccupationReqOutput()
|
||||||
|
|
||||||
|
def profile(self, recv_req: ProfileReq):
|
||||||
|
if recv_req == ProfileReq.START_PROFILE:
|
||||||
|
self.start_profile()
|
||||||
|
else:
|
||||||
|
self.stop_profile()
|
||||||
|
|
||||||
def start_profile(self) -> None:
|
def start_profile(self) -> None:
|
||||||
if self.profiler is None:
|
if self.profiler is None:
|
||||||
@@ -1660,20 +1651,20 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
logger.info("Profiler is done")
|
logger.info("Profiler is done")
|
||||||
|
|
||||||
def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
|
def open_session(self, recv_req: OpenSessionReqInput):
|
||||||
# handle error
|
# handle error
|
||||||
session_id = recv_req.session_id
|
session_id = recv_req.session_id
|
||||||
if session_id in self.sessions:
|
if session_id in self.sessions:
|
||||||
logger.warning(f"session id {session_id} already exist, cannot open.")
|
logger.warning(f"session id {session_id} already exist, cannot open.")
|
||||||
return session_id, False
|
return OpenSessionReqOutput(session_id, False)
|
||||||
elif session_id is None:
|
elif session_id is None:
|
||||||
logger.warning(f"session id is None, cannot open.")
|
logger.warning(f"session id is None, cannot open.")
|
||||||
return session_id, False
|
return OpenSessionReqOutput(session_id, False)
|
||||||
else:
|
else:
|
||||||
self.sessions[session_id] = Session(
|
self.sessions[session_id] = Session(
|
||||||
recv_req.capacity_of_str_len, session_id
|
recv_req.capacity_of_str_len, session_id
|
||||||
)
|
)
|
||||||
return session_id, True
|
return OpenSessionReqOutput(session_id, True)
|
||||||
|
|
||||||
def close_session(self, recv_req: CloseSessionReqInput):
|
def close_session(self, recv_req: CloseSessionReqInput):
|
||||||
# handle error
|
# handle error
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ from sglang.srt.utils import (
|
|||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
@@ -221,6 +221,43 @@ class TokenizerManager:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._dispatcher = TypeBasedDispatcher(
|
||||||
|
[
|
||||||
|
(BatchStrOut, self._handle_batch_output),
|
||||||
|
(BatchEmbeddingOut, self._handle_batch_output),
|
||||||
|
(BatchTokenIDOut, self._handle_batch_output),
|
||||||
|
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
||||||
|
(
|
||||||
|
UpdateWeightFromDiskReqOutput,
|
||||||
|
self._handle_update_weights_from_disk_req_output,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
InitWeightsUpdateGroupReqOutput,
|
||||||
|
self.init_weights_update_group_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
|
self.update_weights_from_distributed_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
UpdateWeightsFromTensorReqOutput,
|
||||||
|
self.update_weights_from_tensor_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
GetWeightsByNameReqOutput,
|
||||||
|
self.get_weights_by_name_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
ReleaseMemoryOccupationReqOutput,
|
||||||
|
self.release_memory_occupation_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
ResumeMemoryOccupationReqOutput,
|
||||||
|
self.resume_memory_occupation_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
async def generate_request(
|
async def generate_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
@@ -712,110 +749,64 @@ class TokenizerManager:
|
|||||||
"""The event loop that handles requests"""
|
"""The event loop that handles requests"""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
recv_obj: Union[
|
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
BatchStrOut,
|
self._dispatcher(recv_obj)
|
||||||
BatchEmbeddingOut,
|
|
||||||
BatchTokenIDOut,
|
|
||||||
UpdateWeightFromDiskReqOutput,
|
|
||||||
UpdateWeightsFromDistributedReqOutput,
|
|
||||||
GetWeightsByNameReqOutput,
|
|
||||||
InitWeightsUpdateGroupReqOutput,
|
|
||||||
ReleaseMemoryOccupationReqOutput,
|
|
||||||
ResumeMemoryOccupationReqOutput,
|
|
||||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
|
||||||
|
|
||||||
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
|
def _handle_batch_output(
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
|
||||||
state = self.rid_to_state.get(rid, None)
|
):
|
||||||
if state is None:
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
continue
|
state = self.rid_to_state.get(rid, None)
|
||||||
|
if state is None:
|
||||||
|
continue
|
||||||
|
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"id": rid,
|
"id": rid,
|
||||||
"finish_reason": recv_obj.finished_reasons[i],
|
"finish_reason": recv_obj.finished_reasons[i],
|
||||||
"prompt_tokens": recv_obj.prompt_tokens[i],
|
"prompt_tokens": recv_obj.prompt_tokens[i],
|
||||||
}
|
}
|
||||||
|
|
||||||
if getattr(state.obj, "return_logprob", False):
|
if getattr(state.obj, "return_logprob", False):
|
||||||
self.convert_logprob_style(
|
self.convert_logprob_style(
|
||||||
meta_info,
|
meta_info,
|
||||||
state.obj.top_logprobs_num,
|
state.obj.top_logprobs_num,
|
||||||
state.obj.return_text_in_logprobs,
|
state.obj.return_text_in_logprobs,
|
||||||
recv_obj,
|
recv_obj,
|
||||||
i,
|
i,
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
|
||||||
meta_info.update(
|
|
||||||
{
|
|
||||||
"completion_tokens": recv_obj.completion_tokens[i],
|
|
||||||
"cached_tokens": recv_obj.cached_tokens[i],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(recv_obj, BatchStrOut):
|
|
||||||
out_dict = {
|
|
||||||
"text": recv_obj.output_strs[i],
|
|
||||||
"meta_info": meta_info,
|
|
||||||
}
|
|
||||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
|
||||||
out_dict = {
|
|
||||||
"token_ids": recv_obj.output_ids[i],
|
|
||||||
"meta_info": meta_info,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
|
||||||
out_dict = {
|
|
||||||
"embedding": recv_obj.embeddings[i],
|
|
||||||
"meta_info": meta_info,
|
|
||||||
}
|
|
||||||
state.out_list.append(out_dict)
|
|
||||||
state.finished = recv_obj.finished_reasons[i] is not None
|
|
||||||
state.event.set()
|
|
||||||
|
|
||||||
if self.enable_metrics and state.obj.log_metrics:
|
|
||||||
self.collect_metrics(state, recv_obj, i)
|
|
||||||
if (
|
|
||||||
self.dump_requests_folder
|
|
||||||
and state.finished
|
|
||||||
and state.obj.log_metrics
|
|
||||||
):
|
|
||||||
self.dump_requests(state, out_dict)
|
|
||||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
|
||||||
self.session_futures[recv_obj.session_id].set_result(
|
|
||||||
recv_obj.session_id if recv_obj.success else None
|
|
||||||
)
|
)
|
||||||
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
|
||||||
if self.server_args.dp_size == 1:
|
if not isinstance(recv_obj, BatchEmbeddingOut):
|
||||||
self.model_update_result.set_result(recv_obj)
|
meta_info.update(
|
||||||
else: # self.server_args.dp_size > 1
|
{
|
||||||
self.model_update_tmp.append(recv_obj)
|
"completion_tokens": recv_obj.completion_tokens[i],
|
||||||
# set future if the all results are recevied
|
"cached_tokens": recv_obj.cached_tokens[i],
|
||||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
}
|
||||||
self.model_update_result.set_result(self.model_update_tmp)
|
)
|
||||||
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
|
|
||||||
assert (
|
if isinstance(recv_obj, BatchStrOut):
|
||||||
self.server_args.dp_size == 1
|
out_dict = {
|
||||||
), "dp_size must be 1 for init parameter update group"
|
"text": recv_obj.output_strs[i],
|
||||||
self.init_weights_update_group_communicator.handle_recv(recv_obj)
|
"meta_info": meta_info,
|
||||||
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
}
|
||||||
assert (
|
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||||
self.server_args.dp_size == 1
|
out_dict = {
|
||||||
), "dp_size must be 1 for update weights from distributed"
|
"token_ids": recv_obj.output_ids[i],
|
||||||
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
|
"meta_info": meta_info,
|
||||||
elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
|
}
|
||||||
assert (
|
|
||||||
self.server_args.dp_size == 1
|
|
||||||
), "dp_size must be 1 for update weights from distributed"
|
|
||||||
self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
|
|
||||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
|
||||||
self.get_weights_by_name_communicator.handle_recv(recv_obj)
|
|
||||||
elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
|
|
||||||
self.release_memory_occupation_communicator.handle_recv(recv_obj)
|
|
||||||
elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
|
|
||||||
self.resume_memory_occupation_communicator.handle_recv(recv_obj)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid object: {recv_obj=}")
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||||
|
out_dict = {
|
||||||
|
"embedding": recv_obj.embeddings[i],
|
||||||
|
"meta_info": meta_info,
|
||||||
|
}
|
||||||
|
state.out_list.append(out_dict)
|
||||||
|
state.finished = recv_obj.finished_reasons[i] is not None
|
||||||
|
state.event.set()
|
||||||
|
|
||||||
|
if self.enable_metrics and state.obj.log_metrics:
|
||||||
|
self.collect_metrics(state, recv_obj, i)
|
||||||
|
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
|
||||||
|
self.dump_requests(state, out_dict)
|
||||||
|
|
||||||
def convert_logprob_style(
|
def convert_logprob_style(
|
||||||
self,
|
self,
|
||||||
@@ -943,6 +934,20 @@ class TokenizerManager:
|
|||||||
# Schedule the task to run in the background without awaiting it
|
# Schedule the task to run in the background without awaiting it
|
||||||
asyncio.create_task(asyncio.to_thread(background_task))
|
asyncio.create_task(asyncio.to_thread(background_task))
|
||||||
|
|
||||||
|
def _handle_open_session_req_output(self, recv_obj):
|
||||||
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
|
recv_obj.session_id if recv_obj.success else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_update_weights_from_disk_req_output(self, recv_obj):
|
||||||
|
if self.server_args.dp_size == 1:
|
||||||
|
self.model_update_result.set_result(recv_obj)
|
||||||
|
else: # self.server_args.dp_size > 1
|
||||||
|
self.model_update_tmp.append(recv_obj)
|
||||||
|
# set future if the all results are recevied
|
||||||
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||||
|
self.model_update_result.set_result(self.model_update_tmp)
|
||||||
|
|
||||||
|
|
||||||
async def print_exception_wrapper(func):
|
async def print_exception_wrapper(func):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import urllib.request
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from json import dumps
|
from json import dumps
|
||||||
from typing import Optional, Union
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
@@ -363,3 +363,14 @@ def terminate_process(process):
|
|||||||
def print_highlight(html_content: str):
|
def print_highlight(html_content: str):
|
||||||
html_content = str(html_content).replace("\n", "<br>")
|
html_content = str(html_content).replace("\n", "<br>")
|
||||||
display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
|
display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
|
||||||
|
|
||||||
|
|
||||||
|
class TypeBasedDispatcher:
|
||||||
|
def __init__(self, mapping: List[Tuple[Type, Callable]]):
|
||||||
|
self._mapping = mapping
|
||||||
|
|
||||||
|
def __call__(self, obj: Any):
|
||||||
|
for ty, fn in self._mapping:
|
||||||
|
if isinstance(obj, ty):
|
||||||
|
return fn(obj)
|
||||||
|
raise ValueError(f"Invalid object: {obj}")
|
||||||
|
|||||||
Reference in New Issue
Block a user