Fix memory pool usage for the upstream EAGLE change (#4170)
This commit is contained in:
@@ -22,7 +22,10 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
|
MHATokenToKVPoolHost,
|
||||||
|
TokenToKVPoolAllocator,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -127,12 +130,12 @@ class HiCacheController:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mem_pool_device: MHATokenToKVPool,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
mem_pool_host: MHATokenToKVPoolHost,
|
mem_pool_host: MHATokenToKVPoolHost,
|
||||||
write_policy: str = "write_through_selective",
|
write_policy: str = "write_through_selective",
|
||||||
):
|
):
|
||||||
|
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||||
self.mem_pool_device = mem_pool_device
|
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||||
self.mem_pool_host = mem_pool_host
|
self.mem_pool_host = mem_pool_host
|
||||||
self.write_policy = write_policy
|
self.write_policy = write_policy
|
||||||
|
|
||||||
@@ -216,7 +219,7 @@ class HiCacheController:
|
|||||||
"""
|
"""
|
||||||
Load KV caches from host memory to device memory.
|
Load KV caches from host memory to device memory.
|
||||||
"""
|
"""
|
||||||
device_indices = self.mem_pool_device.alloc(len(host_indices))
|
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
||||||
if device_indices is None:
|
if device_indices is None:
|
||||||
return None
|
return None
|
||||||
self.mem_pool_host.protect_load(host_indices)
|
self.mem_pool_host.protect_load(host_indices)
|
||||||
@@ -417,7 +420,7 @@ class HiCacheController:
|
|||||||
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||||
) -> int:
|
) -> int:
|
||||||
if self.mem_pool_host.is_synced(host_indices):
|
if self.mem_pool_host.is_synced(host_indices):
|
||||||
self.mem_pool_device.free(device_indices)
|
self.mem_pool_device_allocator.free(device_indices)
|
||||||
self.mem_pool_host.update_backup(host_indices)
|
self.mem_pool_host.update_backup(host_indices)
|
||||||
return len(device_indices)
|
return len(device_indices)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.managers.cache_controller import HiCacheController
|
from sglang.srt.managers.cache_controller import HiCacheController
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
MHATokenToKVPool,
|
|
||||||
MHATokenToKVPoolHost,
|
MHATokenToKVPoolHost,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
|
TokenToKVPoolAllocator,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
||||||
|
|
||||||
@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool: MHATokenToKVPool,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
):
|
):
|
||||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
|
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||||
|
token_to_kv_pool_allocator.get_kvcache()
|
||||||
|
)
|
||||||
self.cache_controller = HiCacheController(
|
self.cache_controller = HiCacheController(
|
||||||
token_to_kv_pool, self.token_to_kv_pool_host
|
token_to_kv_pool_allocator, self.token_to_kv_pool_host
|
||||||
)
|
)
|
||||||
|
|
||||||
# record the nodes with ongoing write through
|
# record the nodes with ongoing write through
|
||||||
@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
|
|||||||
# todo: dynamically adjust the threshold
|
# todo: dynamically adjust the threshold
|
||||||
self.write_through_threshold = 1
|
self.write_through_threshold = 1
|
||||||
self.load_back_threshold = 10
|
self.load_back_threshold = 10
|
||||||
super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
|
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
TreeNode.counter = 0
|
TreeNode.counter = 0
|
||||||
@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
def _evict_write_through_selective(self, node: TreeNode):
|
def _evict_write_through_selective(self, node: TreeNode):
|
||||||
# evict a node not initiated write to host
|
# evict a node not initiated write to host
|
||||||
self.cache_controller.mem_pool_device.free(node.value)
|
self.cache_controller.mem_pool_device_allocator.free(node.value)
|
||||||
num_evicted = len(node.value)
|
num_evicted = len(node.value)
|
||||||
self._delete_leaf(node)
|
self._delete_leaf(node)
|
||||||
return num_evicted
|
return num_evicted
|
||||||
@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
return last_node, prefix_indices
|
return last_node, prefix_indices
|
||||||
|
|
||||||
def _match_prefix_helper(
|
def _match_prefix_helper(self, node: TreeNode, key: List):
|
||||||
self, node: TreeNode, key: List, value, last_node: TreeNode
|
|
||||||
):
|
|
||||||
node.last_access_time = time.time()
|
node.last_access_time = time.time()
|
||||||
if len(key) == 0:
|
value = []
|
||||||
return
|
while len(key) > 0 and key[0] in node.children.keys():
|
||||||
|
|
||||||
if key[0] in node.children.keys():
|
|
||||||
child = node.children[key[0]]
|
child = node.children[key[0]]
|
||||||
|
child.last_access_time = time.time()
|
||||||
prefix_len = _key_match(child.key, key)
|
prefix_len = _key_match(child.key, key)
|
||||||
if prefix_len < len(child.key):
|
if prefix_len < len(child.key):
|
||||||
new_node = self._split_node(child.key, child, prefix_len)
|
new_node = self._split_node(child.key, child, prefix_len)
|
||||||
self.inc_hit_count(new_node)
|
self.inc_hit_count(new_node)
|
||||||
if not new_node.evicted:
|
if not new_node.evicted:
|
||||||
value.append(new_node.value)
|
value.append(new_node.value)
|
||||||
last_node[0] = new_node
|
node = new_node
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
self.inc_hit_count(child)
|
self.inc_hit_count(child)
|
||||||
if not child.evicted:
|
if not child.evicted:
|
||||||
value.append(child.value)
|
value.append(child.value)
|
||||||
last_node[0] = child
|
node = child
|
||||||
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
|
key = key[prefix_len:]
|
||||||
|
return value, node
|
||||||
|
|
||||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||||
# child node split into new_node -> child
|
# child node split into new_node -> child
|
||||||
|
|||||||
@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device_pool: MHATokenToKVPool,
|
device_pool: MHATokenToKVPool,
|
||||||
host_to_device_ratio: float = 2.0,
|
host_to_device_ratio: float = 3.0,
|
||||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -24,14 +24,10 @@ import requests
|
|||||||
from IPython.display import HTML, display
|
from IPython.display import HTML, display
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# type of content fields, can be only prompts or with images/videos
|
|
||||||
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
|
|
||||||
|
|
||||||
|
|
||||||
def get_exception_traceback():
|
def get_exception_traceback():
|
||||||
etype, value, tb = sys.exc_info()
|
etype, value, tb = sys.exc_info()
|
||||||
|
|||||||
Reference in New Issue
Block a user