From 273b28344bc0125ca03af5d59dc00e56c925f310 Mon Sep 17 00:00:00 2001
From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Date: Sat, 6 Sep 2025 00:06:08 +0000
Subject: [PATCH] [Minor] Refactors KV memory pool (#9842)

---
 python/sglang/srt/mem_cache/memory_pool.py | 78 ++++++++++------------
 test/srt/test_swa_unittest.py              | 42 +++++++-----
 2 files changed, 59 insertions(+), 61 deletions(-)

diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 3bde48da4..af56c580a 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -130,6 +130,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
 
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -205,15 +228,9 @@ class MHATokenToKVPool(KVCache):
 
         self._create_buffers()
 
-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
-
-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -427,43 +444,30 @@ def __init__(
         self,
         size: int,
         size_swa: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
         swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
-        device: str,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
     ):
         self.size = size
         self.size_swa = size_swa
-        self.dtype = dtype
-        self.device = device
         self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size = 1
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-        TokenToKVPoolClass = MHATokenToKVPool
-        self.swa_kv_pool = TokenToKVPoolClass(
+
+        self.swa_kv_pool = token_to_kv_pool_class(
             size=size_swa,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.swa_layer_nums,
-            device=device,
-            enable_memory_saver=False,
+            **kwargs,
         )
-        self.full_kv_pool = TokenToKVPoolClass(
+        self.full_kv_pool = token_to_kv_pool_class(
             size=size,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.full_layer_nums,
-            device=device,
-            enable_memory_saver=False,
+            **kwargs,
         )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
@@ -768,13 +772,7 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -936,13 +934,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             device=self.device,
         )
 
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
diff --git a/test/srt/test_swa_unittest.py b/test/srt/test_swa_unittest.py
index e026d70af..128462029 100644
--- a/test/srt/test_swa_unittest.py
+++ b/test/srt/test_swa_unittest.py
@@ -31,16 +31,18 @@ class TestSWA(unittest.TestCase):
             i for i in range(num_layers) if i not in full_attention_layer_ids_set
         ]
         pool = SWAKVPool(
-            size,
-            size_swa,
-            dtype,
-            num_head,
-            head_dim,
-            swa_attention_layer_ids,
-            full_attention_layer_ids,
-            device,
+            size=size,
+            size_swa=size_swa,
+            dtype=dtype,
+            num_head=num_head,
+            head_dim=head_dim,
+            swa_attention_layer_ids=swa_attention_layer_ids,
+            full_attention_layer_ids=full_attention_layer_ids,
+            device=device,
+        )
+        alloc = SWATokenToKVPoolAllocator(
+            size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool
         )
-        alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool)
         assert alloc.available_size() == size + size_swa
         index = alloc.alloc(1)
         assert alloc.available_size() == size_swa + size_swa - 2
@@ -75,18 +77,22 @@
         )
         # setup kv pool
         kv_pool = SWAKVPool(
-            kv_size,
-            kv_size_swa,
-            dtype,
-            num_head,
-            head_dim,
-            swa_attention_layer_ids,
-            full_attention_layer_ids,
-            device,
+            size=kv_size,
+            size_swa=kv_size_swa,
+            dtype=dtype,
+            num_head=num_head,
+            head_dim=head_dim,
+            swa_attention_layer_ids=swa_attention_layer_ids,
+            full_attention_layer_ids=full_attention_layer_ids,
+            device=device,
         )
         # setup token to kv pool allocator
         allocator = SWATokenToKVPoolAllocator(
-            kv_size, kv_size_swa, dtype, device, kv_pool
+            size=kv_size,
+            size_swa=kv_size_swa,
+            dtype=dtype,
+            device=device,
+            kvcache=kv_pool,
         )
         # setup radix cache
         tree = SWARadixCache(