CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)
This commit is contained in:
@@ -47,6 +47,10 @@ from sglang.srt.managers.io_struct import (
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ReleaseMemoryOccupationReqOutput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
@@ -88,6 +92,7 @@ from sglang.srt.utils import (
|
||||
set_random_seed,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -357,6 +362,10 @@ class Scheduler:
|
||||
t.start()
|
||||
self.parent_process = psutil.Process().parent()
|
||||
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
|
||||
# Init profiler
|
||||
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
|
||||
self.profiler = None
|
||||
@@ -519,6 +528,12 @@ class Scheduler:
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
|
||||
self.release_memory_occupation()
|
||||
self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
|
||||
elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
|
||||
self.resume_memory_occupation()
|
||||
self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
@@ -1538,6 +1553,20 @@ class Scheduler:
|
||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||
return parameter
|
||||
|
||||
def release_memory_occupation(self):
|
||||
self.stashed_model_static_state = _export_static_state(
|
||||
self.tp_worker.worker.model_runner.model
|
||||
)
|
||||
self.memory_saver_adapter.pause()
|
||||
self.flush_cache()
|
||||
|
||||
def resume_memory_occupation(self):
|
||||
self.memory_saver_adapter.resume()
|
||||
_import_static_state(
|
||||
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
||||
)
|
||||
del self.stashed_model_static_state
|
||||
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
@@ -1576,6 +1605,20 @@ class Scheduler:
|
||||
del self.sessions[session_id]
|
||||
|
||||
|
||||
def _export_static_state(model):
|
||||
return dict(
|
||||
buffers=[
|
||||
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _import_static_state(model, static_params):
|
||||
self_named_buffers = dict(model.named_buffers())
|
||||
for name, tensor in static_params["buffers"]:
|
||||
self_named_buffers[name][...] = tensor
|
||||
|
||||
|
||||
def run_scheduler_process(
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
|
||||
Reference in New Issue
Block a user