Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)
This commit is contained in:
@@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
|
||||
openai = ["openai>=1.0", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
torch_memory_saver = ["torch_memory_saver>=0.0.4"]
|
||||
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
|
||||
decord = ["decord"]
|
||||
test = [
|
||||
"accelerate",
|
||||
|
||||
3
python/sglang/srt/constants.py
Normal file
3
python/sglang/srt/constants.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# GPU Memory Types
|
||||
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
|
||||
GPU_MEMORY_TYPE_WEIGHTS = "weights"
|
||||
@@ -31,6 +31,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
@@ -90,7 +91,7 @@ class DecodeReqToTokenPool:
|
||||
self.max_context_len = max_context_len
|
||||
self.device = device
|
||||
self.pre_alloc_size = pre_alloc_size
|
||||
with memory_saver_adapter.region():
|
||||
with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
|
||||
self.req_to_token = torch.zeros(
|
||||
(size + pre_alloc_size, max_context_len),
|
||||
dtype=torch.int32,
|
||||
|
||||
@@ -479,17 +479,15 @@ class Engine(EngineBase):
|
||||
self.tokenizer_manager.get_weights_by_name(obj, None)
|
||||
)
|
||||
|
||||
def release_memory_occupation(self):
|
||||
"""Release GPU occupation temporarily."""
|
||||
obj = ReleaseMemoryOccupationReqInput()
|
||||
def release_memory_occupation(self, tags: Optional[List[str]] = None):
|
||||
obj = ReleaseMemoryOccupationReqInput(tags=tags)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.release_memory_occupation(obj, None)
|
||||
)
|
||||
|
||||
def resume_memory_occupation(self):
|
||||
"""Resume GPU occupation."""
|
||||
obj = ResumeMemoryOccupationReqInput()
|
||||
def resume_memory_occupation(self, tags: Optional[List[str]] = None):
|
||||
obj = ResumeMemoryOccupationReqInput(tags=tags)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
||||
@@ -670,11 +668,9 @@ def _launch_subprocesses(
|
||||
|
||||
scheduler_procs = []
|
||||
if server_args.dp_size == 1:
|
||||
# Launch tensor parallel scheduler processes
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
|
||||
scheduler_pipe_readers = []
|
||||
|
||||
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
|
||||
@@ -710,6 +706,7 @@ def _launch_subprocesses(
|
||||
writer,
|
||||
),
|
||||
)
|
||||
|
||||
with memory_saver_adapter.configure_subprocess():
|
||||
proc.start()
|
||||
scheduler_procs.append(proc)
|
||||
|
||||
@@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput:
|
||||
|
||||
@dataclass
|
||||
class ReleaseMemoryOccupationReqInput:
|
||||
pass
|
||||
# Optional tags to identify the memory region, which is primarily used for RL
|
||||
# Currently we only support `weights` and `kv_cache`
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput:
|
||||
|
||||
@dataclass
|
||||
class ResumeMemoryOccupationReqInput:
|
||||
pass
|
||||
# Optional tags to identify the memory region, which is primarily used for RL
|
||||
# Currently we only support `weights` and `kv_cache`
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -36,6 +36,7 @@ from torch.distributed import barrier
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
INVALID_GRAMMAR_OBJ,
|
||||
create_grammar_backend,
|
||||
@@ -450,8 +451,6 @@ class Scheduler(
|
||||
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||
t.start()
|
||||
self.parent_process = psutil.Process().parent()
|
||||
|
||||
# Init memory saver
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
@@ -2227,23 +2226,40 @@ class Scheduler(
|
||||
return GetWeightsByNameReqOutput(parameter)
|
||||
|
||||
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
|
||||
self.memory_saver_adapter.check_validity(
|
||||
caller_name="release_memory_occupation"
|
||||
)
|
||||
self.stashed_model_static_state = _export_static_state(
|
||||
self.tp_worker.worker.model_runner.model
|
||||
)
|
||||
self.memory_saver_adapter.pause()
|
||||
self.flush_cache()
|
||||
tags = recv_req.tags
|
||||
import subprocess
|
||||
|
||||
if tags is None:
|
||||
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
|
||||
|
||||
if GPU_MEMORY_TYPE_KV_CACHE in tags:
|
||||
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
|
||||
self.flush_cache()
|
||||
|
||||
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
||||
self.stashed_model_static_state = _export_static_state(
|
||||
self.tp_worker.worker.model_runner.model
|
||||
)
|
||||
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
|
||||
|
||||
return ReleaseMemoryOccupationReqOutput()
|
||||
|
||||
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
|
||||
self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
|
||||
self.memory_saver_adapter.resume()
|
||||
_import_static_state(
|
||||
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
||||
)
|
||||
del self.stashed_model_static_state
|
||||
tags = recv_req.tags
|
||||
if tags is None or len(tags) == 0:
|
||||
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
|
||||
|
||||
if GPU_MEMORY_TYPE_WEIGHTS in tags:
|
||||
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
|
||||
_import_static_state(
|
||||
self.tp_worker.worker.model_runner.model,
|
||||
self.stashed_model_static_state,
|
||||
)
|
||||
del self.stashed_model_static_state
|
||||
|
||||
if GPU_MEMORY_TYPE_KV_CACHE in tags:
|
||||
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
|
||||
|
||||
return ResumeMemoryOccupationReqOutput()
|
||||
|
||||
def slow_down(self, recv_req: SlowDownReqInput):
|
||||
|
||||
@@ -35,6 +35,7 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
|
||||
|
||||
@@ -54,6 +55,7 @@ class ReqToTokenPool:
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
)
|
||||
@@ -61,7 +63,7 @@ class ReqToTokenPool:
|
||||
self.size = size
|
||||
self.max_context_len = max_context_len
|
||||
self.device = device
|
||||
with memory_saver_adapter.region():
|
||||
with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
self.req_to_token = torch.zeros(
|
||||
(size, max_context_len), dtype=torch.int32, device=device
|
||||
)
|
||||
@@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache):
|
||||
)
|
||||
|
||||
def _create_buffers(self):
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.enable_custom_mem_pool
|
||||
@@ -610,7 +612,7 @@ class MLATokenToKVPool(KVCache):
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.custom_mem_pool
|
||||
@@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
end_layer,
|
||||
)
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
|
||||
@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.srt.distributed import (
|
||||
get_tp_group,
|
||||
get_world_group,
|
||||
@@ -222,6 +223,7 @@ class ModelRunner:
|
||||
|
||||
def initialize(self, min_per_gpu_memory: float):
|
||||
server_args = self.server_args
|
||||
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=self.server_args.enable_memory_saver
|
||||
)
|
||||
@@ -547,7 +549,7 @@ class ModelRunner:
|
||||
monkey_patch_vllm_parallel_state()
|
||||
monkey_patch_isinstance_for_vllm_base_layer()
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
try:
|
||||
import torch_memory_saver
|
||||
|
||||
_primary_memory_saver = torch_memory_saver.TorchMemorySaver()
|
||||
_memory_saver = torch_memory_saver.torch_memory_saver
|
||||
import_error = None
|
||||
except ImportError as e:
|
||||
import_error = e
|
||||
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
|
||||
def configure_subprocess(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def region(self):
|
||||
def region(self, tag: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def pause(self):
|
||||
def pause(self, tag: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def resume(self):
|
||||
def resume(self, tag: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
|
||||
|
||||
|
||||
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
|
||||
"""Adapter for TorchMemorySaver with tag-based control"""
|
||||
|
||||
def configure_subprocess(self):
|
||||
return torch_memory_saver.configure_subprocess()
|
||||
|
||||
def region(self):
|
||||
return _primary_memory_saver.region()
|
||||
def region(self, tag: str):
|
||||
return _memory_saver.region(tag=tag)
|
||||
|
||||
def pause(self):
|
||||
return _primary_memory_saver.pause()
|
||||
def pause(self, tag: str):
|
||||
return _memory_saver.pause(tag=tag)
|
||||
|
||||
def resume(self):
|
||||
return _primary_memory_saver.resume()
|
||||
def resume(self, tag: str):
|
||||
return _memory_saver.resume(tag=tag)
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return _primary_memory_saver.enabled
|
||||
return _memory_saver is not None and _memory_saver.enabled
|
||||
|
||||
|
||||
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
||||
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def region(self):
|
||||
def region(self, tag: str):
|
||||
yield
|
||||
|
||||
def pause(self):
|
||||
def pause(self, tag: str):
|
||||
pass
|
||||
|
||||
def resume(self):
|
||||
def resume(self, tag: str):
|
||||
pass
|
||||
|
||||
@property
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.utils import get_exception_traceback
|
||||
# General test models
|
||||
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
|
||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user