update weights from disk (#2265)

This commit is contained in:
Chayenne
2024-11-29 17:17:00 -08:00
committed by GitHub
parent b53d6cbda3
commit 7d5d1d3d29
11 changed files with 54 additions and 40 deletions

View File

@@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import (
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -506,10 +506,10 @@ class Scheduler:
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
success, message = self.update_weights_from_disk(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightReqOutput(success, message)
UpdateWeightFromDiskReqOutput(success, message)
)
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
@@ -1363,9 +1363,9 @@ class Scheduler:
req.to_abort = True
break
def update_weights(self, recv_req: UpdateWeightReqInput):
"""In-place update of the weights."""
success, message = self.tp_worker.update_weights(recv_req)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"