From 9376ac361d845b422848fbeefbfa204613ad68e9 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Fri, 7 Mar 2025 00:58:20 -0800 Subject: [PATCH] Memory pool fix for upstream change about eagle (#4170) --- .../sglang/srt/managers/cache_controller.py | 15 +++++---- python/sglang/srt/mem_cache/hiradix_cache.py | 33 ++++++++++--------- python/sglang/srt/mem_cache/memory_pool.py | 2 +- python/sglang/utils.py | 4 --- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 993c9e5c2..003836d81 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -22,7 +22,10 @@ from typing import List, Optional import torch -from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost +from sglang.srt.mem_cache.memory_pool import ( + MHATokenToKVPoolHost, + TokenToKVPoolAllocator, +) logger = logging.getLogger(__name__) @@ -127,12 +130,12 @@ class HiCacheController: def __init__( self, - mem_pool_device: MHATokenToKVPool, + token_to_kv_pool_allocator: TokenToKVPoolAllocator, mem_pool_host: MHATokenToKVPoolHost, write_policy: str = "write_through_selective", ): - - self.mem_pool_device = mem_pool_device + self.mem_pool_device_allocator = token_to_kv_pool_allocator + self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() self.mem_pool_host = mem_pool_host self.write_policy = write_policy @@ -216,7 +219,7 @@ class HiCacheController: """ Load KV caches from host memory to device memory. """ - device_indices = self.mem_pool_device.alloc(len(host_indices)) + device_indices = self.mem_pool_device_allocator.alloc(len(host_indices)) if device_indices is None: return None self.mem_pool_host.protect_load(host_indices) @@ -417,7 +420,7 @@ class HiCacheController: self, device_indices: torch.Tensor, host_indices: torch.Tensor ) -> int: if self.mem_pool_host.is_synced(host_indices): - self.mem_pool_device.free(device_indices) + self.mem_pool_device_allocator.free(device_indices) self.mem_pool_host.update_backup(host_indices) return len(device_indices) else: diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 051f66f77..28bab2869 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -7,9 +7,9 @@ import torch from sglang.srt.managers.cache_controller import HiCacheController from sglang.srt.mem_cache.memory_pool import ( - MHATokenToKVPool, MHATokenToKVPoolHost, ReqToTokenPool, + TokenToKVPoolAllocator, ) from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match @@ -21,11 +21,13 @@ class HiRadixCache(RadixCache): def __init__( self, req_to_token_pool: ReqToTokenPool, - token_to_kv_pool: MHATokenToKVPool, + token_to_kv_pool_allocator: TokenToKVPoolAllocator, ): - self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool) + self.token_to_kv_pool_host = MHATokenToKVPoolHost( + token_to_kv_pool_allocator.get_kvcache() + ) self.cache_controller = HiCacheController( - token_to_kv_pool, self.token_to_kv_pool_host + token_to_kv_pool_allocator, self.token_to_kv_pool_host ) # record the nodes with ongoing write through @@ -35,7 +37,7 @@ class HiRadixCache(RadixCache): # todo: dynamically adjust the threshold self.write_through_threshold = 1 self.load_back_threshold = 10 - super().__init__(req_to_token_pool, token_to_kv_pool, disable=False) + super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False) def reset(self): TreeNode.counter = 0 @@ -160,7 +162,7 @@ class HiRadixCache(RadixCache): def _evict_write_through_selective(self, node: TreeNode): # evict a node not initiated write to host - self.cache_controller.mem_pool_device.free(node.value) + self.cache_controller.mem_pool_device_allocator.free(node.value) num_evicted = len(node.value) self._delete_leaf(node) return num_evicted @@ -270,28 +272,27 @@ class HiRadixCache(RadixCache): return last_node, prefix_indices - def _match_prefix_helper( - self, node: TreeNode, key: List, value, last_node: TreeNode - ): + def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.time() - if len(key) == 0: - return - - if key[0] in node.children.keys(): + value = [] + while len(key) > 0 and key[0] in node.children.keys(): child = node.children[key[0]] + child.last_access_time = time.time() prefix_len = _key_match(child.key, key) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) self.inc_hit_count(new_node) if not new_node.evicted: value.append(new_node.value) - last_node[0] = new_node + node = new_node + break else: self.inc_hit_count(child) if not child.evicted: value.append(child.value) - last_node[0] = child - self._match_prefix_helper(child, key[prefix_len:], value, last_node) + node = child + key = key[prefix_len:] + return value, node def _split_node(self, key, child: TreeNode, split_len: int): # child node split into new_node -> child diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 94f65059e..36a1bd8d6 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -470,7 +470,7 @@ class MHATokenToKVPoolHost: def __init__( self, device_pool: MHATokenToKVPool, - host_to_device_ratio: float = 2.0, + host_to_device_ratio: float = 3.0, pin_memory: bool = False, # no need to use pin memory with the double buffering device: str = "cpu", ): diff --git a/python/sglang/utils.py b/python/sglang/utils.py index b358a4b9e..4a751aa88 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -24,14 +24,10 @@ import requests from IPython.display import HTML, display from tqdm import tqdm -from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart from sglang.srt.utils import kill_process_tree logger = logging.getLogger(__name__) -# type of content fields, can be only prompts or with images/videos -MsgContent = Union[str, List[ChatCompletionMessageContentPart]] - def get_exception_traceback(): etype, value, tb = sys.exc_info()