Memory pool fix for upstream change about eagle (#4170)
This commit is contained in:
@@ -7,9 +7,9 @@ import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MHATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
||||
|
||||
@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: MHATokenToKVPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
):
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
token_to_kv_pool_allocator.get_kvcache()
|
||||
)
|
||||
self.cache_controller = HiCacheController(
|
||||
token_to_kv_pool, self.token_to_kv_pool_host
|
||||
token_to_kv_pool_allocator, self.token_to_kv_pool_host
|
||||
)
|
||||
|
||||
# record the nodes with ongoing write through
|
||||
@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
|
||||
# todo: dynamically adjust the threshold
|
||||
self.write_through_threshold = 1
|
||||
self.load_back_threshold = 10
|
||||
super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
|
||||
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
|
||||
|
||||
def reset(self):
|
||||
TreeNode.counter = 0
|
||||
@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
def _evict_write_through_selective(self, node: TreeNode):
|
||||
# evict a node not initiated write to host
|
||||
self.cache_controller.mem_pool_device.free(node.value)
|
||||
self.cache_controller.mem_pool_device_allocator.free(node.value)
|
||||
num_evicted = len(node.value)
|
||||
self._delete_leaf(node)
|
||||
return num_evicted
|
||||
@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return last_node, prefix_indices
|
||||
|
||||
def _match_prefix_helper(
|
||||
self, node: TreeNode, key: List, value, last_node: TreeNode
|
||||
):
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.time()
|
||||
if len(key) == 0:
|
||||
return
|
||||
|
||||
if key[0] in node.children.keys():
|
||||
value = []
|
||||
while len(key) > 0 and key[0] in node.children.keys():
|
||||
child = node.children[key[0]]
|
||||
child.last_access_time = time.time()
|
||||
prefix_len = _key_match(child.key, key)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
self.inc_hit_count(new_node)
|
||||
if not new_node.evicted:
|
||||
value.append(new_node.value)
|
||||
last_node[0] = new_node
|
||||
node = new_node
|
||||
break
|
||||
else:
|
||||
self.inc_hit_count(child)
|
||||
if not child.evicted:
|
||||
value.append(child.value)
|
||||
last_node[0] = child
|
||||
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
|
||||
node = child
|
||||
key = key[prefix_len:]
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
# child node split into new_node -> child
|
||||
|
||||
@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MHATokenToKVPool,
|
||||
host_to_device_ratio: float = 2.0,
|
||||
host_to_device_ratio: float = 3.0,
|
||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||
device: str = "cpu",
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user