Online weight updates from torch.distributed (#2279)

This commit is contained in:
Chayenne
2024-12-01 23:23:18 -08:00
committed by GitHub
parent 28bc60dcab
commit 983bfcf386
12 changed files with 1120 additions and 61 deletions

View File

@@ -40,6 +40,8 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -47,6 +49,8 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -516,6 +520,19 @@ class Scheduler:
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
success, message = self.init_weights_update_group(recv_req)
self.send_to_tokenizer.send_pyobj(
InitWeightsUpdateGroupReqOutput(success, message)
)
elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
success, message = self.update_weights_from_distributed(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightsFromDistributedReqOutput(success, message)
)
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -1378,6 +1395,23 @@ class Scheduler:
logger.error(message)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
"""Update the online model parameter."""
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return parameter