Support loading weights from remote instance (#8215)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
@@ -81,6 +81,8 @@ from sglang.srt.managers.io_struct import (
|
||||
GetInternalStateReqOutput,
|
||||
GetWeightsByNameReqInput,
|
||||
HealthCheckOutput,
|
||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||
InitWeightsSendGroupForRemoteInstanceReqOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
@@ -93,6 +95,8 @@ from sglang.srt.managers.io_struct import (
|
||||
ResumeMemoryOccupationReqInput,
|
||||
RpcReqInput,
|
||||
RpcReqOutput,
|
||||
SendWeightsToRemoteInstanceReqInput,
|
||||
SendWeightsToRemoteInstanceReqOutput,
|
||||
SetInternalStateReq,
|
||||
SetInternalStateReqOutput,
|
||||
SlowDownReqInput,
|
||||
@@ -538,6 +542,14 @@ class Scheduler(
|
||||
(CloseSessionReqInput, self.close_session),
|
||||
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||
(
|
||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||
self.init_weights_send_group_for_remote_instance,
|
||||
),
|
||||
(
|
||||
SendWeightsToRemoteInstanceReqInput,
|
||||
self.send_weights_to_remote_instance,
|
||||
),
|
||||
(
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
self.update_weights_from_distributed,
|
||||
@@ -2429,6 +2441,22 @@ class Scheduler(
|
||||
self.send_to_detokenizer.send_pyobj(recv_req)
|
||||
return recv_req
|
||||
|
||||
def init_weights_send_group_for_remote_instance(
    self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
    """Initialize the communication group between the seed and client instances.

    The request is forwarded unchanged to
    ``self.tp_worker.init_weights_send_group_for_remote_instance``; the
    two-element result it returns is repackaged as the reply object.

    Args:
        recv_req: The incoming group-initialization request.

    Returns:
        An ``InitWeightsSendGroupForRemoteInstanceReqOutput`` built from the
        worker's ``(success, message)`` pair.
    """
    worker_result = self.tp_worker.init_weights_send_group_for_remote_instance(
        recv_req
    )
    # The worker reports its outcome as a (success, message) pair.
    ok, detail = worker_result
    return InitWeightsSendGroupForRemoteInstanceReqOutput(ok, detail)
|
||||
|
||||
def send_weights_to_remote_instance(
    self, recv_req: SendWeightsToRemoteInstanceReqInput
):
    """Send the seed instance's weights to the destination instance.

    The request is forwarded unchanged to
    ``self.tp_worker.send_weights_to_remote_instance``; the two-element
    result it returns is repackaged as the reply object.

    Args:
        recv_req: The incoming send-weights request.

    Returns:
        A ``SendWeightsToRemoteInstanceReqOutput`` built from the worker's
        ``(success, message)`` pair.
    """
    worker_result = self.tp_worker.send_weights_to_remote_instance(recv_req)
    # The worker reports its outcome as a (success, message) pair.
    ok, detail = worker_result
    return SendWeightsToRemoteInstanceReqOutput(ok, detail)
|
||||
|
||||
def slow_down(self, recv_req: SlowDownReqInput):
|
||||
t = recv_req.forward_sleep_time
|
||||
if t is not None and t <= 0:
|
||||
|
||||
Reference in New Issue
Block a user