Hierarchical Caching supports MLA (#4009)
Signed-off-by: Changqi Lu <luchangqi.123@bytedance.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -8,7 +8,10 @@ import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MHATokenToKVPoolHost,
|
||||
MLATokenToKVPool,
|
||||
MLATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
@@ -31,9 +34,14 @@ class HiRadixCache(RadixCache):
|
||||
raise ValueError(
|
||||
"Page size larger than 1 is not yet supported in HiRadixCache."
|
||||
)
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
token_to_kv_pool_allocator.get_kvcache()
|
||||
)
|
||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache)
|
||||
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
||||
self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache)
|
||||
else:
|
||||
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
|
||||
|
||||
self.tp_group = tp_cache_group
|
||||
self.page_size = page_size
|
||||
|
||||
@@ -317,13 +325,11 @@ class HiRadixCache(RadixCache):
|
||||
prefix_len = _key_match(child.key, key)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
self.inc_hit_count(new_node)
|
||||
if not new_node.evicted:
|
||||
value.append(new_node.value)
|
||||
node = new_node
|
||||
break
|
||||
else:
|
||||
self.inc_hit_count(child)
|
||||
if not child.evicted:
|
||||
value.append(child.value)
|
||||
node = child
|
||||
|
||||
Reference in New Issue
Block a user