Memory pool fix for upstream change about eagle (#4170)

This commit is contained in:
Zhiqiang Xie
2025-03-07 00:58:20 -08:00
committed by GitHub
parent 94a2b9d33e
commit 9376ac361d
4 changed files with 27 additions and 27 deletions

View File

@@ -7,9 +7,9 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MHATokenToKVPoolHost,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: MHATokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.cache_controller = HiCacheController(
token_to_kv_pool, self.token_to_kv_pool_host
token_to_kv_pool_allocator, self.token_to_kv_pool_host
)
# record the nodes with ongoing write through
@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
# todo: dynamically adjust the threshold
self.write_through_threshold = 1
self.load_back_threshold = 10
super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
def reset(self):
TreeNode.counter = 0
@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
def _evict_write_through_selective(self, node: TreeNode):
# evict a node not initiated write to host
self.cache_controller.mem_pool_device.free(node.value)
self.cache_controller.mem_pool_device_allocator.free(node.value)
num_evicted = len(node.value)
self._delete_leaf(node)
return num_evicted
@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
return last_node, prefix_indices
def _match_prefix_helper(
self, node: TreeNode, key: List, value, last_node: TreeNode
):
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
if len(key) == 0:
return
if key[0] in node.children.keys():
value = []
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
self.inc_hit_count(new_node)
if not new_node.evicted:
value.append(new_node.value)
last_node[0] = new_node
node = new_node
break
else:
self.inc_hit_count(child)
if not child.evicted:
value.append(child.value)
last_node[0] = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
node = child
key = key[prefix_len:]
return value, node
def _split_node(self, key, child: TreeNode, split_len: int):
# child node split into new_node -> child

View File

@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 2.0,
host_to_device_ratio: float = 3.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu",
):