Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)

This commit is contained in:
Stefan He
2025-06-19 00:56:37 -07:00
committed by GitHub
parent 9179ea1595
commit 3774f07825
14 changed files with 297 additions and 108 deletions

View File

@@ -35,6 +35,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
@@ -54,6 +55,7 @@ class ReqToTokenPool:
device: str,
enable_memory_saver: bool,
):
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
@@ -61,7 +63,7 @@ class ReqToTokenPool:
self.size = size
self.max_context_len = max_context_len
self.device = device
with memory_saver_adapter.region():
with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
@@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache):
)
def _create_buffers(self):
with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
@@ -610,7 +612,7 @@ class MLATokenToKVPool(KVCache):
else:
self.custom_mem_pool = None
with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
@@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache):
end_layer,
)
with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.zeros(