From 70645f4d7d1447d2c5aa0667b88af92a20018b17 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Sun, 20 Apr 2025 23:08:30 -0700 Subject: [PATCH] upstream hicache fixes (#5570) --- python/sglang/srt/managers/schedule_batch.py | 8 +++ python/sglang/srt/managers/scheduler.py | 2 + python/sglang/srt/mem_cache/hiradix_cache.py | 72 +++++++++++--------- python/sglang/srt/mem_cache/memory_pool.py | 27 ++++---- python/sglang/srt/server_args.py | 16 ++++- test/srt/test_hicache.py | 4 ++ test/srt/test_hicache_mla.py | 2 + test/srt/test_hicache_page.py | 4 +- 8 files changed, 89 insertions(+), 46 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a324a2e5d..ddacb7441 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -571,6 +571,14 @@ class Req: self.prefix_indices, self.last_node = tree_cache.match_prefix( rid=self.rid, key=self.adjust_max_prefix_ids() ) + elif enable_hierarchical_cache: + # in case last_node is evicted during scheduling, we need to update the prefix_indices + while self.last_node.evicted: + self.prefix_indices = self.prefix_indices[ + : -len(self.last_node.host_value) + ] + self.last_node = self.last_node.parent + self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d2a601f91..3f96f106c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -489,6 +489,8 @@ class Scheduler( tp_cache_group=self.tp_cpu_group, page_size=self.page_size, hicache_ratio=server_args.hicache_ratio, + hicache_size=server_args.hicache_size, + hicache_write_policy=server_args.hicache_write_policy, ) else: self.tree_cache = RadixCache( diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 
2d29126b3..1e720844a 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -29,15 +29,17 @@ class HiRadixCache(RadixCache): tp_cache_group: torch.distributed.ProcessGroup, page_size: int, hicache_ratio: float, + hicache_size: int, + hicache_write_policy: str, ): self.kv_cache = token_to_kv_pool_allocator.get_kvcache() if isinstance(self.kv_cache, MHATokenToKVPool): self.token_to_kv_pool_host = MHATokenToKVPoolHost( - self.kv_cache, hicache_ratio, page_size + self.kv_cache, hicache_ratio, hicache_size, page_size ) elif isinstance(self.kv_cache, MLATokenToKVPool): self.token_to_kv_pool_host = MLATokenToKVPoolHost( - self.kv_cache, hicache_ratio, page_size + self.kv_cache, hicache_ratio, hicache_size, page_size ) else: raise ValueError(f"HiRadixCache only supports MHA and MLA yet") @@ -50,6 +52,7 @@ class HiRadixCache(RadixCache): self.token_to_kv_pool_host, page_size, load_cache_event=self.load_cache_event, + write_policy=hicache_write_policy, ) # record the nodes with ongoing write through @@ -57,7 +60,9 @@ class HiRadixCache(RadixCache): # record the node segments with ongoing load back self.ongoing_load_back = {} # todo: dynamically adjust the threshold - self.write_through_threshold = 1 + self.write_through_threshold = ( + 1 if hicache_write_policy == "write_through" else 3 + ) self.load_back_threshold = 10 super().__init__( req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False @@ -76,7 +81,7 @@ class HiRadixCache(RadixCache): height += 1 return height - def write_backup(self, node: TreeNode): + def write_backup(self, node: TreeNode, write_back=False): host_indices = self.cache_controller.write( device_indices=node.value, node_id=node.id, @@ -90,21 +95,29 @@ class HiRadixCache(RadixCache): if host_indices is not None: node.host_value = host_indices self.ongoing_write_through[node.id] = node - self.inc_lock_ref(node) + if not write_back: + # no need to lock nodes if write back + 
self.inc_lock_ref(node) else: return 0 return len(host_indices) def inc_hit_count(self, node: TreeNode): - if self.cache_controller.write_policy != "write_through_selective": + if node.backuped or self.cache_controller.write_policy == "write_back": return node.hit_count += 1 - if node.host_value is None and node.hit_count > self.write_through_threshold: + if node.hit_count >= self.write_through_threshold: self.write_backup(node) node.hit_count = 0 - def writing_check(self): + def writing_check(self, write_back=False): + if write_back: + # blocking till all write back complete + while len(self.ongoing_write_through) > 0: + ack_id = self.cache_controller.ack_write_queue.get() + del self.ongoing_write_through[ack_id] + return queue_size = torch.tensor( self.cache_controller.ack_write_queue.qsize(), dtype=torch.int ) @@ -143,29 +156,25 @@ class HiRadixCache(RadixCache): heapq.heapify(leaves) num_evicted = 0 - pending_nodes = [] + write_back_nodes = [] while num_evicted < num_tokens and len(leaves): x = heapq.heappop(leaves) if x.lock_ref > 0: continue - if x.host_value is None: + if not x.backuped: if self.cache_controller.write_policy == "write_back": - num_evicted += self.write_backup(x) - pending_nodes.append(x) - elif self.cache_controller.write_policy == "write_through_selective": - num_evicted += self._evict_write_through_selective(x) + # write to host if the node is not backuped + num_evicted += self.write_backup(x, write_back=True) + write_back_nodes.append(x) else: - assert ( - self.cache_controller.write_policy != "write_through" - ), "write_through should be inclusive" - raise NotImplementedError + num_evicted += self._evict_regular(x) else: - num_evicted += self._evict_write_through(x) + num_evicted += self._evict_backuped(x) for child in x.parent.children.values(): - if child in pending_nodes: + if child in write_back_nodes: continue if not child.evicted: break @@ -174,15 +183,12 @@ class HiRadixCache(RadixCache): heapq.heappush(leaves, x.parent) if 
self.cache_controller.write_policy == "write_back": - # blocking till all write back complete - while len(self.ongoing_write_through) > 0: - self.writing_check() - time.sleep(0.1) - for node in pending_nodes: - assert node.host_value is not None - self._evict_write_through(node) + self.writing_check(write_back=True) + for node in write_back_nodes: + assert node.backuped + self._evict_backuped(node) - def _evict_write_through(self, node: TreeNode): + def _evict_backuped(self, node: TreeNode): # evict a node already written to host num_evicted = self.cache_controller.evict_device(node.value, node.host_value) assert num_evicted > 0 @@ -190,7 +196,7 @@ class HiRadixCache(RadixCache): node.value = None return num_evicted - def _evict_write_through_selective(self, node: TreeNode): + def _evict_regular(self, node: TreeNode): # evict a node not initiated write to host self.cache_controller.mem_pool_device_allocator.free(node.value) num_evicted = len(node.value) @@ -339,11 +345,13 @@ class HiRadixCache(RadixCache): prefix_len = self.key_match_fn(child.key, key) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) + self.inc_hit_count(new_node) if not new_node.evicted: value.append(new_node.value) node = new_node break else: + self.inc_hit_count(child) if not child.evicted: value.append(child.value) node = child @@ -369,7 +377,7 @@ class HiRadixCache(RadixCache): else: new_node.value = child.value[:split_len] child.value = child.value[split_len:] - if child.host_value is not None: + if child.backuped: new_node.host_value = child.host_value[:split_len] child.host_value = child.host_value[split_len:] child.parent = new_node @@ -426,8 +434,8 @@ class HiRadixCache(RadixCache): node.children[child_key] = new_node self.evictable_size_ += len(value) - if self.cache_controller.write_policy == "write_through": - self.write_backup(new_node) + if self.cache_controller.write_policy != "write_back": + self.inc_hit_count(new_node) return 
total_prefix_length def _collect_leaves_device(self): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 079dfd9b1..9f5ecbdbe 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -624,26 +624,27 @@ class HostKVCache(abc.ABC): self, device_pool: MHATokenToKVPool, host_to_device_ratio: float, + host_size: int, pin_memory: bool, device: str, page_size: int, ): - assert ( - host_to_device_ratio >= 1 - ), "The host memory should be larger than the device memory with the current protocol" - # todo, other ways of configuring the size - self.device_pool = device_pool - self.host_to_device_ratio = host_to_device_ratio + self.dtype = device_pool.store_dtype self.pin_memory = pin_memory self.device = device self.page_size = page_size - - self.size = int(device_pool.size * host_to_device_ratio) + self.size_per_token = self.get_size_per_token() + if host_size > 0: + self.size = int(host_size * 1e9 // self.size_per_token) + else: + self.size = int(device_pool.size * host_to_device_ratio) # Align the host memory pool size to the page size self.size = self.size - (self.size % self.page_size) - self.dtype = device_pool.store_dtype - self.size_per_token = self.get_size_per_token() + + assert ( + self.size > device_pool.size + ), "The host memory should be larger than the device memory with the current protocol" # Verify there is enough available host memory. 
host_mem = psutil.virtual_memory() @@ -795,12 +796,13 @@ class MHATokenToKVPoolHost(HostKVCache): self, device_pool: MHATokenToKVPool, host_to_device_ratio: float, + host_size: int, page_size: int, pin_memory: bool = True, device: str = "cpu", ): super().__init__( - device_pool, host_to_device_ratio, pin_memory, device, page_size + device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size ) def get_size_per_token(self): @@ -869,12 +871,13 @@ class MLATokenToKVPoolHost(HostKVCache): self, device_pool: MLATokenToKVPool, host_to_device_ratio: float, + host_size: int, page_size: int, pin_memory: bool = True, device: str = "cpu", ): super().__init__( - device_pool, host_to_device_ratio, pin_memory, device, page_size + device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size ) def get_size_per_token(self): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e294a56f1..ba7833879 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -180,6 +180,8 @@ class ServerArgs: tool_call_parser: Optional[str] = None enable_hierarchical_cache: bool = False hicache_ratio: float = 2.0 + hicache_size: int = 0 + hicache_write_policy: str = "write_through_selective" flashinfer_mla_disable_ragged: bool = False warmups: Optional[str] = None moe_dense_tp_size: Optional[int] = None @@ -1116,10 +1118,22 @@ class ServerArgs: parser.add_argument( "--hicache-ratio", type=float, - required=False, default=ServerArgs.hicache_ratio, help="The ratio of the size of host KV cache memory pool to the size of device pool.", ) + parser.add_argument( + "--hicache-size", + type=int, + default=ServerArgs.hicache_size, + help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.", + ) + parser.add_argument( + "--hicache-write-policy", + type=str, + choices=["write_back", "write_through", "write_through_selective"], + default=ServerArgs.hicache_write_policy, + 
help="The write policy of hierarchical cache.", ) parser.add_argument( "--enable-deepep-moe", action="store_true", diff --git a/test/srt/test_hicache.py b/test/srt/test_hicache.py index d651aa047..3fee235ad 100644 --- a/test/srt/test_hicache.py +++ b/test/srt/test_hicache.py @@ -23,6 +23,10 @@ class TestHiCache(CustomTestCase): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--enable-hierarchical-cache", + "--mem-fraction-static", + 0.7, + "--hicache-size", + 100, ], ) diff --git a/test/srt/test_hicache_mla.py b/test/srt/test_hicache_mla.py index 71418470a..5d306453c 100644 --- a/test/srt/test_hicache_mla.py +++ b/test/srt/test_hicache_mla.py @@ -24,6 +24,8 @@ class TestHierarchicalMLA(CustomTestCase): other_args=[ "--trust-remote-code", "--enable-hierarchical-cache", + "--hicache-ratio", + 2, ], ) diff --git a/test/srt/test_hicache_page.py b/test/srt/test_hicache_page.py index f237af51b..c110d054e 100644 --- a/test/srt/test_hicache_page.py +++ b/test/srt/test_hicache_page.py @@ -24,7 +24,9 @@ class TestHiCachePage(CustomTestCase): other_args=[ "--enable-hierarchical-cache", "--page-size", - "32", + 32, + "--hicache-write-policy", + "write_back", ], )