diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 0d9e6275d..94f65059e 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -20,9 +20,8 @@ Memory pool.
 
 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-TokenToKVPoolAllocator maps a token location to its KV cache data.
-KVCache actually holds the physical kv cache. Allocation indices are allocated
-by TokenToKVPoolAllocator
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """
 
 import abc
@@ -92,14 +91,40 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
 class TokenToKVPoolAllocator:
-    """A memory pool that maps a token location to its kv cache data."""
+    """An allocator managing the indices to kv cache data."""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
         self.free_group = []
         self.clear()
 
+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)
 
+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None
@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
         self.free_group = []
 
 
-class KVCache(abc.ABC):
-
-    @abc.abstractmethod
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
 class MHATokenToKVPool(KVCache):
 
     def __init__(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 1aaae45b1..a931cb15a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -710,15 +710,6 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        else:
-            assert self.is_draft_worker
-
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -753,6 +744,17 @@ class ModelRunner:
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+                kvcache=self.token_to_kv_pool,
+            )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
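
Usage note (not part of the patch): after this refactor the allocator holds a reference to the KVCache that backs its indices, so call sites construct the physical cache first, hand it to TokenToKVPoolAllocator, and later recover it through get_kvcache(). The sketch below is a minimal, hypothetical illustration of that wiring; DummyKVCache and its constructor arguments are invented for the example, while the KVCache abstract methods and the allocator's (size, dtype, device, kvcache) signature come from the diff above.

# Minimal sketch, assuming only the interfaces shown in this diff.
from typing import Tuple

import torch

from sglang.srt.mem_cache.memory_pool import KVCache, TokenToKVPoolAllocator


class DummyKVCache(KVCache):
    """Toy single-layer cache used only to exercise the abstract interface."""

    def __init__(self, size: int, head_num: int, head_dim: int, dtype, device):
        self.k = torch.zeros((size, head_num, head_dim), dtype=dtype, device=device)
        self.v = torch.zeros((size, head_num, head_dim), dtype=dtype, device=device)

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        return self.k

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        return self.v

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.k, self.v

    def set_kv_buffer(self, layer, loc, cache_k, cache_v) -> None:
        # Write the new K/V entries at the allocated token locations.
        self.k[loc] = cache_k
        self.v[loc] = cache_v


# Build the physical cache first, then the allocator that manages its indices.
kvcache = DummyKVCache(size=16, head_num=2, head_dim=8, dtype=torch.float16, device="cpu")
allocator = TokenToKVPoolAllocator(size=16, dtype=torch.float16, device="cpu", kvcache=kvcache)

loc = allocator.alloc(4)                   # indices into the kv cache, or None if full
assert allocator.get_kvcache() is kvcache  # the allocator now exposes its backing cache

This mirrors the new ordering in ModelRunner: the KVCache (e.g. MHATokenToKVPool) is created before the TokenToKVPoolAllocator and passed in via the kvcache argument, instead of the two being created independently.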