Add a pointer to the real KV cache pool (#4113)

Zhiqiang Xie
2025-03-05 21:39:07 -08:00
committed by GitHub
parent 286e6540a6
commit aee30630d8
2 changed files with 45 additions and 38 deletions


@@ -20,9 +20,8 @@ Memory pool.
 
 SGLang has two levels of memory pool.
 ReqToTokenPool maps a request to its token locations.
-TokenToKVPoolAllocator maps a token location to its KV cache data.
-KVCache actually holds the physical kv cache. Allocation indices are allocated
-by TokenToKVPoolAllocator
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """
 
 import abc
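
For orientation, a minimal sketch of the two-level indirection the docstring describes. The shapes and names here are illustrative assumptions, not SGLang's real layout:

import torch

# ReqToTokenPool level: one row per request, holding token locations.
req_to_token = torch.full((8, 128), -1, dtype=torch.int64)
# TokenToKVPoolAllocator level: indices handed out into the physical pool.
kv_index = torch.arange(5)
req_to_token[0, :5] = kv_index  # request 0 occupies 5 token slots
# KVCache level: the physical buffers; gather this request's keys.
k_buffer = torch.zeros(1024, 8, 64)  # [pool size, heads, head dim] (assumed)
req0_keys = k_buffer[req_to_token[0, :5]]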
@@ -92,14 +91,40 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
 class TokenToKVPoolAllocator:
-    """A memory pool that maps a token location to its kv cache data."""
+    """An allocator managing the indices to kv cache data."""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
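
To make the abstract interface concrete, here is a toy KVCache subclass. It is an illustration only: the per-layer buffer shape and the layer.layer_id attribute are assumptions, and MHATokenToKVPool further down in this file is the real implementation.

class ToyKVCache(KVCache):
    def __init__(self, size, num_layers, head_num, head_dim, dtype, device):
        # One [size, head_num, head_dim] buffer per layer for K and for V.
        self.k_buffer = [torch.zeros(size, head_num, head_dim, dtype=dtype, device=device)
                         for _ in range(num_layers)]
        self.v_buffer = [torch.zeros(size, head_num, head_dim, dtype=dtype, device=device)
                         for _ in range(num_layers)]

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

    def set_kv_buffer(self, layer, loc, cache_k, cache_v) -> None:
        # Assumes the attention layer object exposes a layer_id attribute.
        self.k_buffer[layer.layer_id][loc] = cache_k
        self.v_buffer[layer.layer_id][loc] = cache_v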
@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
         self.free_group = []
         self.clear()
 
+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)
 
+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None
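
The get_kvcache accessor above is the "pointer" of the commit title: code that holds only the allocator can now reach the physical pool. A usage sketch, assuming the constructor shown in this diff and a KVCache instance such as the toy one above:

toy_cache = ToyKVCache(size=1024, num_layers=32, head_num=8, head_dim=64,
                       dtype=torch.float16, device="cuda")
allocator = TokenToKVPoolAllocator(size=1024, dtype=torch.float16,
                                   device="cuda", kvcache=toy_cache)
locs = allocator.alloc(5)              # indices into the pool, or None if full
if locs is not None:
    kvcache = allocator.get_kvcache()  # follow the new pointer
    k_buf, v_buf = kvcache.get_kv_buffer(layer_id=0)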
@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
         self.free_group = []
 
 
-class KVCache(abc.ABC):
-    @abc.abstractmethod
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
 class MHATokenToKVPool(KVCache):
     def __init__(