Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)
This commit is contained in:
@@ -35,6 +35,7 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
|
||||
|
||||
@@ -54,6 +55,7 @@ class ReqToTokenPool:
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
)
|
||||
@@ -61,7 +63,7 @@ class ReqToTokenPool:
|
||||
self.size = size
|
||||
self.max_context_len = max_context_len
|
||||
self.device = device
|
||||
with memory_saver_adapter.region():
|
||||
with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
self.req_to_token = torch.zeros(
|
||||
(size, max_context_len), dtype=torch.int32, device=device
|
||||
)
|
||||
@@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache):
|
||||
)
|
||||
|
||||
def _create_buffers(self):
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.enable_custom_mem_pool
|
||||
@@ -610,7 +612,7 @@ class MLATokenToKVPool(KVCache):
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.custom_mem_pool
|
||||
@@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
end_layer,
|
||||
)
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
|
||||
Reference in New Issue
Block a user