Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)

Stefan He
2025-06-19 00:56:37 -07:00
committed by GitHub
parent 9179ea1595
commit 3774f07825
14 changed files with 297 additions and 108 deletions

View File

@@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver>=0.0.4"]
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
decord = ["decord"]
test = [
"accelerate",

View File

@@ -0,0 +1,3 @@
+# GPU Memory Types
+GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
+GPU_MEMORY_TYPE_WEIGHTS = "weights"
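These tags are plain strings, so callers may either import the constants or pass the literals. A minimal sketch of how they feed the new Engine API introduced later in this commit (assuming `engine` is an sglang Engine started with enable_memory_saver=True):

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE

# Release only the KV cache; model weights stay resident on the GPU.
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
# Equivalent call using the raw string value:
engine.release_memory_occupation(tags=["kv_cache"])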

View File

@@ -31,6 +31,7 @@ import numpy as np
import torch
from torch.distributed import ProcessGroup
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
@@ -90,7 +91,7 @@ class DecodeReqToTokenPool:
self.max_context_len = max_context_len
self.device = device
self.pre_alloc_size = pre_alloc_size
-with memory_saver_adapter.region():
+with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
self.req_to_token = torch.zeros(
(size + pre_alloc_size, max_context_len),
dtype=torch.int32,

View File

@@ -479,17 +479,15 @@ class Engine(EngineBase):
self.tokenizer_manager.get_weights_by_name(obj, None)
)
-def release_memory_occupation(self):
-    """Release GPU occupation temporarily."""
-    obj = ReleaseMemoryOccupationReqInput()
+def release_memory_occupation(self, tags: Optional[List[str]] = None):
+    obj = ReleaseMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.release_memory_occupation(obj, None)
)
-def resume_memory_occupation(self):
-    """Resume GPU occupation."""
-    obj = ResumeMemoryOccupationReqInput()
+def resume_memory_occupation(self, tags: Optional[List[str]] = None):
+    obj = ResumeMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.resume_memory_occupation(obj, None)
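With the optional tags, an RL training loop can stage its "awake" path instead of resuming everything at once. A hedged sketch of the intended multi-stage sequence (the weight-sync step is illustrative; any update mechanism fits between the two resume calls):

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

# Sleep: the scheduler pauses the KV cache (and flushes it) before the weights.
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS])

# ... the trainer uses the freed GPU memory here ...

# Multi-stage awake: bring weights back first, sync them, then restore the KV cache.
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
engine.update_weights_from_tensor(...)  # illustrative; arguments elided
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])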
@@ -670,11 +668,9 @@ def _launch_subprocesses(
scheduler_procs = []
if server_args.dp_size == 1:
# Launch tensor parallel scheduler processes
-memory_saver_adapter = TorchMemorySaverAdapter.create(
-    enable=server_args.enable_memory_saver
-)
scheduler_pipe_readers = []
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -710,6 +706,7 @@ def _launch_subprocesses(
writer,
),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc)

View File

@@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput:
@dataclass
class ReleaseMemoryOccupationReqInput:
-pass
+# Optional tags to identify the memory region, which is primarily used for RL
+# Currently we only support `weights` and `kv_cache`
+tags: Optional[List[str]] = None
@dataclass
@@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput:
@dataclass
class ResumeMemoryOccupationReqInput:
-pass
+# Optional tags to identify the memory region, which is primarily used for RL
+# Currently we only support `weights` and `kv_cache`
+tags: Optional[List[str]] = None
@dataclass

View File

@@ -36,6 +36,7 @@ from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
create_grammar_backend,
@@ -450,8 +451,6 @@ class Scheduler(
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
self.parent_process = psutil.Process().parent()
-# Init memory saver
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
@@ -2227,23 +2226,40 @@ class Scheduler(
return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-self.memory_saver_adapter.check_validity(
-    caller_name="release_memory_occupation"
-)
-self.stashed_model_static_state = _export_static_state(
-    self.tp_worker.worker.model_runner.model
-)
-self.memory_saver_adapter.pause()
-self.flush_cache()
+tags = recv_req.tags
+if tags is None:
+    tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+if GPU_MEMORY_TYPE_KV_CACHE in tags:
+    self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+    self.flush_cache()
+if GPU_MEMORY_TYPE_WEIGHTS in tags:
+    self.stashed_model_static_state = _export_static_state(
+        self.tp_worker.worker.model_runner.model
+    )
+    self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
-self.memory_saver_adapter.resume()
-_import_static_state(
-    self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
-)
-del self.stashed_model_static_state
+tags = recv_req.tags
+if tags is None or len(tags) == 0:
+    tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+if GPU_MEMORY_TYPE_WEIGHTS in tags:
+    self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+    _import_static_state(
+        self.tp_worker.worker.model_runner.model,
+        self.stashed_model_static_state,
+    )
+    del self.stashed_model_static_state
+if GPU_MEMORY_TYPE_KV_CACHE in tags:
+    self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
return ResumeMemoryOccupationReqOutput()
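The resume ordering is deliberate: the weight region must be re-materialized before `_import_static_state` writes the stashed state back into it, and symmetrically the state is exported before the weights are paused on release. A hedged sketch of that stash/restore idea (not the actual `_export_static_state`/`_import_static_state` implementation):

import torch

def export_static_state(model: torch.nn.Module) -> dict:
    # Stash non-pauseable state (e.g. named buffers) on CPU while weight memory is released.
    return {name: buf.detach().to("cpu", copy=True) for name, buf in model.named_buffers()}

def import_static_state(model: torch.nn.Module, stash: dict) -> None:
    # Valid only after the weight region is resumed: copy_ writes into the
    # re-materialized GPU buffers.
    for name, buf in model.named_buffers():
        if name in stash:
            buf.copy_(stash[name].to(buf.device))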
def slow_down(self, recv_req: SlowDownReqInput):

View File

@@ -35,6 +35,7 @@ import torch
import triton
import triton.language as tl
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
@@ -54,6 +55,7 @@ class ReqToTokenPool:
device: str,
enable_memory_saver: bool,
):
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
@@ -61,7 +63,7 @@ class ReqToTokenPool:
self.size = size
self.max_context_len = max_context_len
self.device = device
-with memory_saver_adapter.region():
+with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
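Every tensor allocated inside a tagged region can later be paused and resumed by that tag alone, which is what makes the per-type control in this commit possible. A minimal sketch of the underlying torch_memory_saver (>=0.0.8) API that the adapter wraps; the behavior notes are assumptions based on this diff:

import torch
from torch_memory_saver import torch_memory_saver  # global singleton

with torch_memory_saver.region(tag="kv_cache"):
    kv_buffer = torch.zeros(1024, 1024, device="cuda")  # tracked under "kv_cache"

torch_memory_saver.pause(tag="kv_cache")   # physical memory released; tensor object survives
torch_memory_saver.resume(tag="kv_cache")  # re-materialized; kv_buffer is usable again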
@@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache):
)
def _create_buffers(self):
-with self.memory_saver_adapter.region():
+with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
@@ -610,7 +612,7 @@ class MLATokenToKVPool(KVCache):
else:
self.custom_mem_pool = None
-with self.memory_saver_adapter.region():
+with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
@@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache):
end_layer,
)
-with self.memory_saver_adapter.region():
+with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.zeros(

View File

@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
get_tp_group,
get_world_group,
@@ -222,6 +223,7 @@ class ModelRunner:
def initialize(self, min_per_gpu_memory: float):
server_args = self.server_args
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.server_args.enable_memory_saver
)
@@ -547,7 +549,7 @@ class ModelRunner:
monkey_patch_vllm_parallel_state()
monkey_patch_isinstance_for_vllm_base_layer()
-with self.memory_saver_adapter.region():
+with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,

View File

@@ -1,11 +1,13 @@
import logging
import threading
import time
from abc import ABC
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
try:
import torch_memory_saver
-_primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+_memory_saver = torch_memory_saver.torch_memory_saver
import_error = None
except ImportError as e:
import_error = e
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
def configure_subprocess(self):
raise NotImplementedError
-def region(self):
+def region(self, tag: str):
raise NotImplementedError
-def pause(self):
+def pause(self, tag: str):
raise NotImplementedError
-def resume(self):
+def resume(self, tag: str):
raise NotImplementedError
@property
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
"""Adapter for TorchMemorySaver with tag-based control"""
def configure_subprocess(self):
return torch_memory_saver.configure_subprocess()
-def region(self):
-    return _primary_memory_saver.region()
+def region(self, tag: str):
+    return _memory_saver.region(tag=tag)
-def pause(self):
-    return _primary_memory_saver.pause()
+def pause(self, tag: str):
+    return _memory_saver.pause(tag=tag)
-def resume(self):
-    return _primary_memory_saver.resume()
+def resume(self, tag: str):
+    return _memory_saver.resume(tag=tag)
@property
def enabled(self):
-return _primary_memory_saver.enabled
+return _memory_saver is not None and _memory_saver.enabled
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
yield
@contextmanager
-def region(self):
+def region(self, tag: str):
yield
-def pause(self):
+def pause(self, tag: str):
pass
-def resume(self):
+def resume(self, tag: str):
pass
@property
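Taken together, callers keep a single code path for both modes: create(enable=False) returns the no-op variant, while the real adapter forwards each call to the global saver with its tag. A hedged usage sketch (module path as used elsewhere in sglang):

import torch
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=True)
with adapter.region(tag=GPU_MEMORY_TYPE_WEIGHTS):
    w = torch.empty(4096, 4096, device="cuda")  # tagged as "weights"

adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)   # frees the GPU pages for this tag
adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)  # w is usable again afterwards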

View File

@@ -37,6 +37,7 @@ from sglang.utils import get_exception_traceback
# General test models
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"