Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)
This commit is contained in:
@@ -36,6 +36,7 @@ from torch.distributed import barrier
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
INVALID_GRAMMAR_OBJ,
|
||||
create_grammar_backend,
|
||||
@@ -450,8 +451,6 @@ class Scheduler(
|
||||
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||
t.start()
|
||||
self.parent_process = psutil.Process().parent()
|
||||
|
||||
# Init memory saver
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
@@ -2227,23 +2226,40 @@ class Scheduler(
|
||||
return GetWeightsByNameReqOutput(parameter)
|
||||
|
||||
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
    """Pause the GPU memory regions selected by ``recv_req.tags``.

    When ``recv_req.tags`` is ``None`` or empty, both the weights and the
    KV cache regions are selected; ``resume_memory_occupation`` applies the
    same default, so a release/resume pair always targets the same regions.

    Args:
        recv_req: Request whose ``tags`` attribute lists the memory types
            to release.

    Returns:
        An empty ``ReleaseMemoryOccupationReqOutput``.
    """
    tags = recv_req.tags
    if not tags:
        tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

    if GPU_MEMORY_TYPE_KV_CACHE in tags:
        self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
        self.flush_cache()

    if GPU_MEMORY_TYPE_WEIGHTS in tags:
        # Stash the model's static state first; resume_memory_occupation
        # re-imports it from self.stashed_model_static_state.
        # NOTE(review): presumably this state cannot be read once the
        # weights region is paused, hence export-before-pause -- confirm.
        self.stashed_model_static_state = _export_static_state(
            self.tp_worker.worker.model_runner.model
        )
        self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

    return ReleaseMemoryOccupationReqOutput()
|
||||
|
||||
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
    """Resume the GPU memory regions selected by ``recv_req.tags``.

    When ``recv_req.tags`` is ``None`` or empty, both the weights and the
    KV cache regions are resumed.

    Args:
        recv_req: Request whose ``tags`` attribute lists the memory types
            to resume.

    Returns:
        An empty ``ResumeMemoryOccupationReqOutput``.

    Raises:
        AttributeError: If the weights region is resumed while no
            ``stashed_model_static_state`` is present on the scheduler
            (it is set by ``release_memory_occupation`` and deleted here).
    """
    tags = recv_req.tags
    if not tags:
        tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

    if GPU_MEMORY_TYPE_WEIGHTS in tags:
        self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
        # Restore the static state stashed at release time, then drop the
        # stash so it is consumed exactly once.
        _import_static_state(
            self.tp_worker.worker.model_runner.model,
            self.stashed_model_static_state,
        )
        del self.stashed_model_static_state

    if GPU_MEMORY_TYPE_KV_CACHE in tags:
        self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

    return ResumeMemoryOccupationReqOutput()
|
||||
|
||||
def slow_down(self, recv_req: SlowDownReqInput):
|
||||
|
||||
Reference in New Issue
Block a user