Add a pointer to the real KV cache pool (#4113)
This commit is contained in:
@@ -20,9 +20,8 @@ Memory pool.
|
||||
|
||||
SGLang has two levels of memory pool.
|
||||
ReqToTokenPool maps a request to its token locations.
|
||||
TokenToKVPoolAllocator maps a token location to its KV cache data.
|
||||
KVCache actually holds the physical kv cache. Allocation indices are allocated
|
||||
by TokenToKVPoolAllocator
|
||||
TokenToKVPoolAllocator manages the indices to kv cache data.
|
||||
KVCache actually holds the physical kv cache.
|
||||
"""
|
||||
|
||||
import abc
|
||||
@@ -92,14 +91,40 @@ class ReqToTokenPool:
|
||||
self.free_slots = list(range(self.size))
|
||||
|
||||
|
||||
class KVCache(abc.ABC):
    """Abstract interface for the object that holds the physical KV cache.

    Every accessor below is abstract: a concrete subclass must override all
    four before it can be instantiated. The default bodies only raise
    ``NotImplementedError``.
    """

    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the key buffer associated with ``layer_id``."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the value buffer associated with ``layer_id``."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a ``(key, value)`` pair of buffers for ``layer_id``."""
        raise NotImplementedError

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        """Write ``cache_k`` and ``cache_v`` into the cache for ``layer``.

        NOTE(review): ``loc`` presumably carries the target cache indices;
        confirm against the concrete implementations.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class TokenToKVPoolAllocator:
|
||||
"""A memory pool that maps a token location to its kv cache data."""
|
||||
"""An allocator managing the indices to kv cache data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
|
||||
self.free_group = []
|
||||
self.clear()
|
||||
|
||||
self._kvcache = kvcache
|
||||
|
||||
def available_size(self):
    """Return the number of entries currently held in ``self.free_slots``."""
    remaining = len(self.free_slots)
    return remaining
|
||||
|
||||
def get_kvcache(self):
    """Return the object stored on this allocator as ``self._kvcache``."""
    kv_store = self._kvcache
    return kv_store
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
if need_size > len(self.free_slots):
|
||||
return None
|
||||
@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
|
||||
self.free_group = []
|
||||
|
||||
|
||||
class KVCache(abc.ABC):
    """Abstract interface for the object that holds the physical KV cache.

    All four methods are abstract, so only subclasses that override every one
    of them can be instantiated. The default bodies only raise
    ``NotImplementedError``.
    """

    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the key buffer associated with ``layer_id``."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the value buffer associated with ``layer_id``."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a ``(key, value)`` pair of buffers for ``layer_id``."""
        raise NotImplementedError

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        """Write ``cache_k`` and ``cache_v`` into the cache for ``layer``.

        NOTE(review): ``loc`` presumably carries the target cache indices;
        confirm against the concrete implementations.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class MHATokenToKVPool(KVCache):
|
||||
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user