From 08104b56de1192468c322e6f9ba234ef6526d607 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Mon, 27 Jan 2025 12:28:17 -0800 Subject: [PATCH] Sanity check to prevent performance regression (#3171) Co-authored-by: Lianmin Zheng --- python/sglang/srt/managers/scheduler.py | 20 ++++++++++-- .../sglang/srt/mem_cache/base_prefix_cache.py | 4 +++ python/sglang/srt/mem_cache/chunk_cache.py | 3 ++ python/sglang/srt/mem_cache/radix_cache.py | 31 ++++++++++++++++++- python/sglang/srt/server_args.py | 6 ++++ 5 files changed, 60 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2b7462958..79d4db114 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -149,6 +149,7 @@ class Scheduler: if not self.spec_algorithm.is_none() else 1 ) + self.enable_hierarchical_cache = server_args.enable_hierarchical_cache # Distributed rank info self.dp_size = server_args.dp_size @@ -831,10 +832,16 @@ class Scheduler: available_size = ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - if available_size != self.max_total_num_tokens: + protected_size = self.tree_cache.protected_size() + memory_leak = available_size != ( + self.max_total_num_tokens + if not self.enable_hierarchical_cache + else self.max_total_num_tokens - protected_size + ) + if memory_leak: msg = ( "KV cache pool leak detected!" - f"{available_size=}, {self.max_total_num_tokens=}\n" + f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n" ) warnings.warn(msg) if crash_on_warnings(): @@ -949,7 +956,14 @@ class Scheduler: res = adder.add_one_req(req) if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: - self.batch_is_full = True + if self.enable_hierarchical_cache: + # Set batch_is_full after making sure there are requests that can be served + self.batch_is_full = len(adder.can_run_list) > 0 or ( + self.running_batch is not None + and not self.running_batch.is_empty() + ) + else: + self.batch_is_full = True break if self.server_args.prefill_only_one_req: break diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index acdd2898f..9386595a8 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -41,6 +41,10 @@ class BasePrefixCache(ABC): def evictable_size(self): pass + @abstractmethod + def protected_size(self): + raise NotImplementedError() + def total_size(self): raise NotImplementedError() diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index ab8965a01..b50199ca2 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache): def evictable_size(self): return 0 + + def protected_size(self): + return 0 diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 1673d4f0c..3bf87b542 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -34,7 +34,10 @@ if TYPE_CHECKING: class TreeNode: - def __init__(self): + + counter = 0 + + def __init__(self, id: Optional[int] = None): self.children = defaultdict(TreeNode) self.parent = None self.key = None @@ -42,6 +45,23 @@ class TreeNode: self.lock_ref = 0 self.last_access_time = time.time() + self.hit_count = 0 + # indicating the node is loading KV cache from host + self.loading = False + # store the host indices of KV cache + self.host_value = None + + self.id = TreeNode.counter if id is None else id + TreeNode.counter += 1 + + @property + def evicted(self): + return self.value is None + + @property + def backuped(self): + return self.host_value is not None + def __lt__(self, other: "TreeNode"): return self.last_access_time < other.last_access_time @@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache): self.root_node.value = [] self.root_node.lock_ref = 1 self.evictable_size_ = 0 + self.protected_size_ = 0 def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: """Find the matching prefix from the radix tree. @@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache): while node != self.root_node: if node.lock_ref == 0: self.evictable_size_ -= len(node.value) + self.protected_size_ += len(node.value) delta -= len(node.value) node.lock_ref += 1 node = node.parent @@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache): while node != self.root_node: if node.lock_ref == 1: self.evictable_size_ += len(node.value) + self.protected_size_ -= len(node.value) delta += len(node.value) node.lock_ref -= 1 node = node.parent @@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache): def evictable_size(self): return self.evictable_size_ + def protected_size(self): + # protected size refers to the size of the cache that is locked + return self.protected_size_ + ##### Internal Helper Functions ##### def _match_prefix_helper( @@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache): self.evictable_size_ -= len(node.key) def _total_size_helper(self, node: TreeNode): + if node.evicted: + return 0 x = len(node.value) for child in node.children.values(): x += self._total_size_helper(child) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7bee34657..f9340e477 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -163,6 +163,7 @@ class ServerArgs: # Custom logit processor enable_custom_logit_processor: bool = False tool_call_parser: str = None + enable_hierarchical_cache: bool = False def __post_init__(self): # Set missing default values @@ -892,6 +893,11 @@ class ServerArgs: default=ServerArgs.tool_call_parser, help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.", ) + parser.add_argument( + "--enable-hierarchical-cache", + action="store_true", + help="Enable hierarchical cache", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace):