Online weight updates from torch.distributed (#2279)

This commit is contained in:
Chayenne
2024-12-01 23:23:18 -08:00
committed by GitHub
parent 28bc60dcab
commit 983bfcf386
12 changed files with 1120 additions and 61 deletions

View File

@@ -25,7 +25,9 @@ import torch
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
@@ -211,6 +213,16 @@ class TpModelWorkerClient:
success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.worker.init_weights_update_group(recv_req)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.worker.update_weights_from_distributed(recv_req)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return self.worker.get_weights_by_name(recv_req)