Add a pointer to the real KV cache pool (#4113)
This commit is contained in:
@@ -20,9 +20,8 @@ Memory pool.
|
||||
|
||||
SGLang has two levels of memory pool.
|
||||
ReqToTokenPool maps a request to its token locations.
|
||||
TokenToKVPoolAllocator maps a token location to its KV cache data.
|
||||
KVCache actually holds the physical kv cache. Allocation indices are allocated
|
||||
by TokenToKVPoolAllocator
|
||||
TokenToKVPoolAllocator manages the indices to kv cache data.
|
||||
KVCache actually holds the physical kv cache.
|
||||
"""
|
||||
|
||||
import abc
|
||||
@@ -92,14 +91,40 @@ class ReqToTokenPool:
|
||||
self.free_slots = list(range(self.size))
|
||||
|
||||
|
||||
class KVCache(abc.ABC):
    """Abstract interface over the physical KV cache storage.

    Concrete subclasses (e.g. ``MHATokenToKVPool``) own the actual key/value
    tensors; this interface only defines per-layer read and write access.
    Slot indices are handed out separately by ``TokenToKVPoolAllocator``.
    """

    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the key buffer tensor for layer ``layer_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the value buffer tensor for layer ``layer_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(key, value)`` buffer pair for layer ``layer_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        """Write ``cache_k``/``cache_v`` into ``layer``'s buffers at token locations ``loc``."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TokenToKVPoolAllocator:
|
||||
"""A memory pool that maps a token location to its kv cache data."""
|
||||
"""An allocator managing the indices to kv cache data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
|
||||
self.free_group = []
|
||||
self.clear()
|
||||
|
||||
self._kvcache = kvcache
|
||||
|
||||
def available_size(self):
    # Number of free KV-cache slots still available; the free list holds one
    # entry per unallocated token location.
    return len(self.free_slots)
|
||||
|
||||
def get_kvcache(self):
    # Expose the KVCache object backing this allocator (set from the
    # `kvcache` constructor argument) so callers can reach the physical
    # key/value buffers.
    return self._kvcache
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
if need_size > len(self.free_slots):
|
||||
return None
|
||||
@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
|
||||
self.free_group = []
|
||||
|
||||
|
||||
class KVCache(abc.ABC):
    """Abstract base class for the physical KV cache.

    Subclasses hold the real key/value tensors and implement per-layer
    access; allocation of token slots is handled elsewhere
    (by TokenToKVPoolAllocator).
    """

    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the key buffer for the given layer."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the value buffer for the given layer."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (key, value) buffer pair for the given layer."""
        raise NotImplementedError()

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        """Store cache_k/cache_v for `layer` at token locations `loc`."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class MHATokenToKVPool(KVCache):
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -710,15 +710,6 @@ class ModelRunner:
|
||||
# Draft worker shares req_to_token_pool with the target worker.
|
||||
assert self.is_draft_worker
|
||||
|
||||
if self.token_to_kv_pool_allocator is None:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
assert self.is_draft_worker
|
||||
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not self.server_args.disable_mla
|
||||
@@ -753,6 +744,17 @@ class ModelRunner:
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
|
||||
if self.token_to_kv_pool_allocator is None:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
)
|
||||
else:
|
||||
assert self.is_draft_worker
|
||||
|
||||
logger.info(
|
||||
f"Memory pool end. "
|
||||
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
|
||||
Reference in New Issue
Block a user