Support loading weights from remote instance (#8215)

Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
amysaq2023
2025-09-12 17:40:22 +08:00
committed by GitHub
parent 1b1701f1f7
commit 30d20ce84f
18 changed files with 1042 additions and 6 deletions

View File

@@ -81,6 +81,8 @@ from sglang.srt.managers.io_struct import (
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
HealthCheckOutput,
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsSendGroupForRemoteInstanceReqOutput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
@@ -93,6 +95,8 @@ from sglang.srt.managers.io_struct import (
ResumeMemoryOccupationReqInput,
RpcReqInput,
RpcReqOutput,
SendWeightsToRemoteInstanceReqInput,
SendWeightsToRemoteInstanceReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
SlowDownReqInput,
@@ -538,6 +542,14 @@ class Scheduler(
(CloseSessionReqInput, self.close_session),
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
(
InitWeightsSendGroupForRemoteInstanceReqInput,
self.init_weights_send_group_for_remote_instance,
),
(
SendWeightsToRemoteInstanceReqInput,
self.send_weights_to_remote_instance,
),
(
UpdateWeightsFromDistributedReqInput,
self.update_weights_from_distributed,
@@ -2429,6 +2441,22 @@ class Scheduler(
self.send_to_detokenizer.send_pyobj(recv_req)
return recv_req
def init_weights_send_group_for_remote_instance(
    self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
    """Set up the communication group between the seed and client instances.

    The request is handed unchanged to ``self.tp_worker``; the two values it
    returns (a success flag and a message) are wrapped in an
    ``InitWeightsSendGroupForRemoteInstanceReqOutput`` for the caller.
    """
    worker = self.tp_worker
    ok, detail = worker.init_weights_send_group_for_remote_instance(recv_req)
    return InitWeightsSendGroupForRemoteInstanceReqOutput(ok, detail)
def send_weights_to_remote_instance(
    self, recv_req: SendWeightsToRemoteInstanceReqInput
):
    """Push this (seed) instance's weights to the destination instance.

    The request is handed unchanged to ``self.tp_worker``; the two values it
    returns (a success flag and a message) are wrapped in a
    ``SendWeightsToRemoteInstanceReqOutput`` for the caller.
    """
    outcome = self.tp_worker.send_weights_to_remote_instance(recv_req)
    # Exactly two values are expected back; unpacking enforces that.
    ok, detail = outcome
    return SendWeightsToRemoteInstanceReqOutput(ok, detail)
def slow_down(self, recv_req: SlowDownReqInput):
t = recv_req.forward_sleep_time
if t is not None and t <= 0: