CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)

This commit is contained in:
fzyzcjy
2025-01-14 03:38:51 +08:00
committed by GitHub
parent d08c77c434
commit 923f518337
12 changed files with 406 additions and 60 deletions

View File

@@ -47,6 +47,10 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
@@ -88,6 +92,7 @@ from sglang.srt.utils import (
set_random_seed,
suppress_other_loggers,
)
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -357,6 +362,10 @@ class Scheduler:
t.start()
self.parent_process = psutil.Process().parent()
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
# Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
@@ -519,6 +528,12 @@ class Scheduler:
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
self.release_memory_occupation()
self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
self.resume_memory_occupation()
self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -1538,6 +1553,20 @@ class Scheduler:
parameter = self.tp_worker.get_weights_by_name(recv_req)
return parameter
def release_memory_occupation(self):
    """Snapshot the model's buffers, pause the memory saver, and flush the cache.

    The buffer snapshot is taken first and kept on
    ``self.stashed_model_static_state`` so that
    ``resume_memory_occupation`` can write the values back later.
    """
    # Take the snapshot before pausing the memory saver.
    # NOTE(review): presumably the buffer contents are not readable once
    # the adapter is paused -- confirm against TorchMemorySaverAdapter.
    model = self.tp_worker.worker.model_runner.model
    self.stashed_model_static_state = _export_static_state(model)

    self.memory_saver_adapter.pause()
    self.flush_cache()
def resume_memory_occupation(self):
    """Resume the memory saver and write the stashed buffer values back.

    Reads ``self.stashed_model_static_state`` (set by
    ``release_memory_occupation``) and removes that attribute once the
    values have been copied into the model.
    """
    self.memory_saver_adapter.resume()

    # Restore buffer contents only after the adapter has been resumed.
    model = self.tp_worker.worker.model_runner.model
    stashed_state = self.stashed_model_static_state
    _import_static_state(model, stashed_state)

    # The snapshot is single-use; drop the attribute so the copies can be freed.
    del self.stashed_model_static_state
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
@@ -1576,6 +1605,20 @@ class Scheduler:
del self.sessions[session_id]
def _export_static_state(model):
    """Return detached copies of every named buffer of ``model``.

    The result is a dict with a single key, ``"buffers"``, holding a list
    of ``(name, tensor)`` pairs in ``model.named_buffers()`` order. Each
    tensor is a ``detach().clone()`` of the live buffer, so later changes
    to the model do not affect the snapshot.
    """
    snapshot = []
    for buffer_name, buffer_tensor in model.named_buffers():
        snapshot.append((buffer_name, buffer_tensor.detach().clone()))
    return {"buffers": snapshot}
def _import_static_state(model, static_params):
    """Copy saved buffer values back into ``model`` in place.

    ``static_params["buffers"]`` is iterated as ``(name, tensor)`` pairs;
    each tensor is written into the model buffer of the same name through
    a full-slice assignment, so the existing buffer objects are kept. A
    name that is not among ``model.named_buffers()`` raises ``KeyError``.
    """
    live_buffers = {name: buf for name, buf in model.named_buffers()}
    for buffer_name, saved_tensor in static_params["buffers"]:
        target = live_buffers[buffer_name]
        # Full-slice assignment overwrites the contents without rebinding.
        target[...] = saved_tensor
def run_scheduler_process(
server_args: ServerArgs,
port_args: PortArgs,