Online weight updates from torch.distributed (#2279)
This commit is contained in:
@@ -25,7 +25,9 @@ import torch
|
||||
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
@@ -211,6 +213,16 @@ class TpModelWorkerClient:
|
||||
success, message = self.worker.update_weights_from_disk(recv_req)
|
||||
return success, message
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
success, message = self.worker.init_weights_update_group(recv_req)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
success, message = self.worker.update_weights_from_distributed(recv_req)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
return self.worker.get_weights_by_name(recv_req)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user