Memory pool fix for upstream change about eagle (#4170)
This commit is contained in:
@@ -22,7 +22,10 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPoolHost,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -127,12 +130,12 @@ class HiCacheController:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mem_pool_device: MHATokenToKVPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
mem_pool_host: MHATokenToKVPoolHost,
|
||||
write_policy: str = "write_through_selective",
|
||||
):
|
||||
|
||||
self.mem_pool_device = mem_pool_device
|
||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||
self.mem_pool_host = mem_pool_host
|
||||
self.write_policy = write_policy
|
||||
|
||||
@@ -216,7 +219,7 @@ class HiCacheController:
|
||||
"""
|
||||
Load KV caches from host memory to device memory.
|
||||
"""
|
||||
device_indices = self.mem_pool_device.alloc(len(host_indices))
|
||||
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
||||
if device_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_load(host_indices)
|
||||
@@ -417,7 +420,7 @@ class HiCacheController:
|
||||
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||
) -> int:
|
||||
if self.mem_pool_host.is_synced(host_indices):
|
||||
self.mem_pool_device.free(device_indices)
|
||||
self.mem_pool_device_allocator.free(device_indices)
|
||||
self.mem_pool_host.update_backup(host_indices)
|
||||
return len(device_indices)
|
||||
else:
|
||||
|
||||
@@ -7,9 +7,9 @@ import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MHATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
||||
|
||||
@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: MHATokenToKVPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
):
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
token_to_kv_pool_allocator.get_kvcache()
|
||||
)
|
||||
self.cache_controller = HiCacheController(
|
||||
token_to_kv_pool, self.token_to_kv_pool_host
|
||||
token_to_kv_pool_allocator, self.token_to_kv_pool_host
|
||||
)
|
||||
|
||||
# record the nodes with ongoing write through
|
||||
@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
|
||||
# todo: dynamically adjust the threshold
|
||||
self.write_through_threshold = 1
|
||||
self.load_back_threshold = 10
|
||||
super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
|
||||
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
|
||||
|
||||
def reset(self):
|
||||
TreeNode.counter = 0
|
||||
@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
def _evict_write_through_selective(self, node: TreeNode):
|
||||
# evict a node not initiated write to host
|
||||
self.cache_controller.mem_pool_device.free(node.value)
|
||||
self.cache_controller.mem_pool_device_allocator.free(node.value)
|
||||
num_evicted = len(node.value)
|
||||
self._delete_leaf(node)
|
||||
return num_evicted
|
||||
@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
return last_node, prefix_indices
|
||||
|
||||
def _match_prefix_helper(
|
||||
self, node: TreeNode, key: List, value, last_node: TreeNode
|
||||
):
|
||||
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||
node.last_access_time = time.time()
|
||||
if len(key) == 0:
|
||||
return
|
||||
|
||||
if key[0] in node.children.keys():
|
||||
value = []
|
||||
while len(key) > 0 and key[0] in node.children.keys():
|
||||
child = node.children[key[0]]
|
||||
child.last_access_time = time.time()
|
||||
prefix_len = _key_match(child.key, key)
|
||||
if prefix_len < len(child.key):
|
||||
new_node = self._split_node(child.key, child, prefix_len)
|
||||
self.inc_hit_count(new_node)
|
||||
if not new_node.evicted:
|
||||
value.append(new_node.value)
|
||||
last_node[0] = new_node
|
||||
node = new_node
|
||||
break
|
||||
else:
|
||||
self.inc_hit_count(child)
|
||||
if not child.evicted:
|
||||
value.append(child.value)
|
||||
last_node[0] = child
|
||||
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
|
||||
node = child
|
||||
key = key[prefix_len:]
|
||||
return value, node
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
# child node split into new_node -> child
|
||||
|
||||
@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MHATokenToKVPool,
|
||||
host_to_device_ratio: float = 2.0,
|
||||
host_to_device_ratio: float = 3.0,
|
||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||
device: str = "cpu",
|
||||
):
|
||||
|
||||
@@ -24,14 +24,10 @@ import requests
|
||||
from IPython.display import HTML, display
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# type of content fields, can be only prompts or with images/videos
|
||||
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
|
||||
|
||||
|
||||
def get_exception_traceback():
|
||||
etype, value, tb = sys.exc_info()
|
||||
|
||||
Reference in New Issue
Block a user