Online weight updates from torch.distributed (#2279)
This commit is contained in:
@@ -40,6 +40,8 @@ from sglang.srt.managers.io_struct import (
|
||||
FlushCacheReq,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -47,6 +49,8 @@ from sglang.srt.managers.io_struct import (
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
@@ -516,6 +520,19 @@ class Scheduler:
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
|
||||
success, message = self.init_weights_update_group(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
InitWeightsUpdateGroupReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
|
||||
success, message = self.update_weights_from_distributed(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightsFromDistributedReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
@@ -1378,6 +1395,23 @@ class Scheduler:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
"""Initialize the online model parameter update group."""
|
||||
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
"""Update the online model parameter."""
|
||||
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
||||
if success:
|
||||
flash_cache_success = self.flush_cache()
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||
return parameter
|
||||
|
||||
Reference in New Issue
Block a user