udate weights from disk (#2265)
This commit is contained in:
@@ -43,8 +43,8 @@ from sglang.srt.managers.io_struct import (
|
||||
ProfileReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightReqInput,
|
||||
UpdateWeightReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
@@ -506,10 +506,10 @@ class Scheduler:
|
||||
self.flush_cache()
|
||||
elif isinstance(recv_req, AbortReq):
|
||||
self.abort_request(recv_req)
|
||||
elif isinstance(recv_req, UpdateWeightReqInput):
|
||||
success, message = self.update_weights(recv_req)
|
||||
elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
|
||||
success, message = self.update_weights_from_disk(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightReqOutput(success, message)
|
||||
UpdateWeightFromDiskReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
@@ -1363,9 +1363,9 @@ class Scheduler:
|
||||
req.to_abort = True
|
||||
break
|
||||
|
||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||
"""In-place update of the weights."""
|
||||
success, message = self.tp_worker.update_weights(recv_req)
|
||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||
"""In-place update of the weights from disk."""
|
||||
success, message = self.tp_worker.update_weights_from_disk(recv_req)
|
||||
if success:
|
||||
flash_cache_success = self.flush_cache()
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
|
||||
Reference in New Issue
Block a user