Hierarchical Caching supports MLA (#4009)

Signed-off-by: Changqi Lu <luchangqi.123@bytedance.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
Lu Changqi
2025-03-14 11:42:14 +08:00
committed by GitHub
parent bb37855653
commit 0e0ec70200
4 changed files with 231 additions and 38 deletions

View File

@@ -8,7 +8,10 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MHATokenToKVPoolHost,
MLATokenToKVPool,
MLATokenToKVPoolHost,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
@@ -31,9 +34,14 @@ class HiRadixCache(RadixCache):
raise ValueError(
"Page size larger than 1 is not yet supported in HiRadixCache."
)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache)
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache)
else:
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
self.tp_group = tp_cache_group
self.page_size = page_size
@@ -317,13 +325,11 @@ class HiRadixCache(RadixCache):
prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
self.inc_hit_count(new_node)
if not new_node.evicted:
value.append(new_node.value)
node = new_node
break
else:
self.inc_hit_count(child)
if not child.evicted:
value.append(child.value)
node = child