Add a pointer to the real KV cache pool (#4113)
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -20,9 +20,8 @@ Memory pool.
 
 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-TokenToKVPoolAllocator maps a token location to its KV cache data.
-KVCache actually holds the physical kv cache. Allocation indices are allocated
-by TokenToKVPoolAllocator
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """
 
 import abc
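To make the two levels concrete: a request's row in ReqToTokenPool yields token locations, and those locations index into the physical buffers held by KVCache. A minimal sketch of the lookup path (the helper and its arguments are invented for illustration; only the pool attributes and methods come from this file):

    # Hypothetical illustration of the two-level lookup, not code from this commit.
    def kv_for_request(req_to_token_pool, kv_cache, req_pool_idx, seq_len, layer_id):
        # Level 1: ReqToTokenPool row -> locations of this request's tokens.
        token_locs = req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
        # Level 2: token locations -> physical K/V entries for one layer.
        k_buf, v_buf = kv_cache.get_kv_buffer(layer_id)
        return k_buf[token_locs], v_buf[token_locs]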
@@ -92,14 +91,40 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
 class TokenToKVPoolAllocator:
-    """A memory pool that maps a token location to its kv cache data."""
+    """An allocator managing the indices to kv cache data."""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
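For reference, a toy KVCache subclass showing what the abstract interface demands. This is illustration only (ToyKVCache and its constructor arguments are invented); the real implementation in this file is MHATokenToKVPool below:

    import torch
    from typing import Tuple

    class ToyKVCache(KVCache):
        def __init__(self, size, head_num, head_dim, layer_num, dtype, device):
            # One (size, head_num, head_dim) buffer per layer for K and for V.
            shape = (size, head_num, head_dim)
            self.k_buffer = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(layer_num)]
            self.v_buffer = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(layer_num)]

        def get_key_buffer(self, layer_id: int) -> torch.Tensor:
            return self.k_buffer[layer_id]

        def get_value_buffer(self, layer_id: int) -> torch.Tensor:
            return self.v_buffer[layer_id]

        def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
            return self.k_buffer[layer_id], self.v_buffer[layer_id]

        def set_kv_buffer(self, layer, loc, cache_k, cache_v) -> None:
            # Scatter new entries into the slots chosen by the allocator;
            # layer.layer_id follows RadixAttention's attribute of that name.
            self.k_buffer[layer.layer_id][loc] = cache_k
            self.v_buffer[layer.layer_id][loc] = cache_v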
@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
         self.free_group = []
         self.clear()
 
+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)
 
+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None
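With the pointer stored, any component that holds only the allocator can now reach the physical pool as well. A usage sketch (variable names are invented; alloc and get_kvcache are the methods above):

    # Illustration only: reserve slots for new tokens, then write their
    # K/V data through the pool reached via the allocator.
    locs = allocator.alloc(num_new_tokens)   # token indices, or None when full
    if locs is not None:
        pool = allocator.get_kvcache()       # the KVCache behind those indices
        pool.set_kv_buffer(layer, locs, new_k, new_v)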
@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
         self.free_group = []
 
 
-class KVCache(abc.ABC):
-
-    @abc.abstractmethod
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
 class MHATokenToKVPool(KVCache):
 
     def __init__(
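(This deletion is the second half of a move, not a removal of functionality: the KVCache ABC now sits before TokenToKVPoolAllocator, as added in the earlier hunk, so that the allocator's kvcache: KVCache parameter annotation can refer to it.)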
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -710,15 +710,6 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        else:
-            assert self.is_draft_worker
-
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
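(The allocator construction removed here is re-added in the next hunk, after self.token_to_kv_pool has been created, so it can receive the new kvcache pointer at construction time.)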
@@ -753,6 +744,17 @@ class ModelRunner:
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+                kvcache=self.token_to_kv_pool,
+            )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"