From 54e872d34354d2821f2567897769c31df6b16c8e Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Fri, 29 Aug 2025 10:30:54 -0700 Subject: [PATCH] [HiCache] resolve conflict between chunked-prefill and hicache hit count (#9776) --- python/sglang/srt/disaggregation/prefill.py | 2 +- python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/mem_cache/chunk_cache.py | 2 +- python/sglang/srt/mem_cache/hiradix_cache.py | 17 +++++++++-------- python/sglang/srt/mem_cache/lora_radix_cache.py | 2 +- python/sglang/srt/mem_cache/radix_cache.py | 8 +++++--- python/sglang/srt/mem_cache/radix_cache_cpp.py | 2 +- python/sglang/srt/mem_cache/swa_radix_cache.py | 2 +- 8 files changed, 20 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 063197618..9b80bd4ff 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin: # Move the chunked request out of the batch so that we can merge # only finished requests to running_batch. self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req) - self.tree_cache.cache_unfinished_req(self.chunked_req) + self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True) if self.enable_overlap: # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved self.chunked_req.tmp_end_idx = min( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1feb7c0dd..54028ce65 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1503,7 +1503,7 @@ class Scheduler( # Move the chunked request out of the batch so that we can merge # only finished requests to running_batch. chunked_req_to_exclude.add(self.chunked_req) - self.tree_cache.cache_unfinished_req(self.chunked_req) + self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True) # chunked request keeps its rid but will get a new req_pool_idx self.req_to_token_pool.free(self.chunked_req.req_pool_idx) if self.last_batch and self.last_batch.forward_mode.is_extend(): diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 88d923b46..1a576bfa2 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache): self.req_to_token_pool.free(req.req_pool_idx) self.token_to_kv_pool_allocator.free(kv_indices) - def cache_unfinished_req(self, req: Req): + def cache_unfinished_req(self, req: Req, chunked=False): kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(req.fill_ids) ] diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 61039913a..611e94386 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -102,7 +102,7 @@ class HiRadixCache(RadixCache): self.ongoing_backup = {} # todo: dynamically adjust the threshold self.write_through_threshold = ( - 1 if hicache_write_policy == "write_through" else 3 + 1 if hicache_write_policy == "write_through" else 2 ) self.write_through_threshold_storage = ( 1 if hicache_write_policy == "write_through" else 3 @@ -155,8 +155,9 @@ class HiRadixCache(RadixCache): self.ongoing_backup[operation_id] = node node.protect_host() - def inc_hit_count(self, node: TreeNode): - if self.cache_controller.write_policy == "write_back": + def _inc_hit_count(self, node: TreeNode, chunked=False): + # skip the hit count update for chunked requests + if self.cache_controller.write_policy == "write_back" or chunked: return node.hit_count += 1 @@ -672,11 +673,11 @@ class HiRadixCache(RadixCache): new_node.parent.children[self.get_child_key_fn(key)] = new_node return new_node - def _insert_helper(self, node: TreeNode, key: List, value): - node.last_access_time = time.monotonic() + def insert(self, key: List, value, chunked=False): if len(key) == 0: return 0 + node = self.root_node child_key = self.get_child_key_fn(key) total_prefix_length = 0 @@ -693,7 +694,7 @@ class HiRadixCache(RadixCache): self.token_to_kv_pool_host.update_synced(node.host_value) self.evictable_size_ += len(node.value) else: - self.inc_hit_count(node) + self._inc_hit_count(node, chunked) total_prefix_length += prefix_len else: # partial match, split the node @@ -703,7 +704,7 @@ class HiRadixCache(RadixCache): self.token_to_kv_pool_host.update_synced(new_node.host_value) self.evictable_size_ += len(new_node.value) else: - self.inc_hit_count(new_node) + self._inc_hit_count(new_node, chunked) total_prefix_length += prefix_len node = new_node @@ -737,7 +738,7 @@ class HiRadixCache(RadixCache): last_hash = new_node.hash_value[-1] if self.cache_controller.write_policy != "write_back": - self.inc_hit_count(new_node) + self._inc_hit_count(new_node, chunked) return total_prefix_length def _collect_leaves_device(self): diff --git a/python/sglang/srt/mem_cache/lora_radix_cache.py b/python/sglang/srt/mem_cache/lora_radix_cache.py index fa5626012..32b115cb6 100644 --- a/python/sglang/srt/mem_cache/lora_radix_cache.py +++ b/python/sglang/srt/mem_cache/lora_radix_cache.py @@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache): self.req_to_token_pool.free(req.req_pool_idx) self.dec_lock_ref(req.last_node) - def cache_unfinished_req(self, req: Req): + def cache_unfinished_req(self, req: Req, chunked=False): """Cache request when it is unfinished.""" if self.disable: return diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index f6383b4ce..a586b8696 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -195,7 +195,7 @@ class RadixCache(BasePrefixCache): last_host_node=last_node, ) - def insert(self, key: List, value=None): + def insert(self, key: List, value=None, chunked=False): if self.disable: return 0 @@ -240,7 +240,7 @@ class RadixCache(BasePrefixCache): self.req_to_token_pool.free(req.req_pool_idx) self.dec_lock_ref(req.last_node) - def cache_unfinished_req(self, req: Req): + def cache_unfinished_req(self, req: Req, chunked=False): """Cache request when it is unfinished.""" if self.disable: return @@ -261,7 +261,9 @@ class RadixCache(BasePrefixCache): page_aligned_token_ids = token_ids[:page_aligned_len] # Radix Cache takes one ref in memory pool - new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices) + new_prefix_len = self.insert( + page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked + ) self.token_to_kv_pool_allocator.free( kv_indices[len(req.prefix_indices) : new_prefix_len] ) diff --git a/python/sglang/srt/mem_cache/radix_cache_cpp.py b/python/sglang/srt/mem_cache/radix_cache_cpp.py index 5234f1a0f..e9512e83f 100644 --- a/python/sglang/srt/mem_cache/radix_cache_cpp.py +++ b/python/sglang/srt/mem_cache/radix_cache_cpp.py @@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache): self.dec_lock_ref(req.last_node) self.req_to_token_pool.free(req.req_pool_idx) - def cache_unfinished_req(self, req: Req): + def cache_unfinished_req(self, req: Req, chunked=False): """Cache request when it is unfinished.""" assert req.req_pool_idx is not None token_ids = req.fill_ids diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 7a23eb856..0624e84e1 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache): self.req_to_token_pool.free(req.req_pool_idx) self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) - def cache_unfinished_req(self, req: Req) -> None: + def cache_unfinished_req(self, req: Req, chunked=False) -> None: """Cache request when it is unfinished.""" if self.disable: kv_indices = self.req_to_token_pool.req_to_token[