From 1a31229cd4c29fb529cf82ea6995caeb84f214fc Mon Sep 17 00:00:00 2001 From: Alex Chi Z <4198311+skyzh@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:47:33 -0400 Subject: [PATCH] fix: radix cache memory accounting (#10637) Signed-off-by: Alex Chi Z --- python/sglang/srt/mem_cache/radix_cache.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 2f818770a..edb0495dc 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -267,7 +267,7 @@ class RadixCache(BasePrefixCache): """ key.token_ids = self.key_convert_fn(key.token_ids) - if self.disable or len(key) == 0: + def empty_match_result(): return MatchResult( device_indices=torch.empty( (0,), @@ -278,10 +278,16 @@ class RadixCache(BasePrefixCache): last_host_node=self.root_node, ) + if self.disable or len(key) == 0: + return empty_match_result() + if self.page_size != 1: page_aligned_len = len(key) // self.page_size * self.page_size key = key[:page_aligned_len] + if len(key) == 0: + return empty_match_result() + value, last_node = self._match_prefix_helper(self.root_node, key) if value: value = torch.cat(value) @@ -475,9 +481,9 @@ class RadixCache(BasePrefixCache): delta = 0 while node != self.root_node: if node.lock_ref == 0: - self.evictable_size_ -= len(node.value) - self.protected_size_ += len(node.value) - delta -= len(node.value) + self.evictable_size_ -= len(node.key) + self.protected_size_ += len(node.key) + delta -= len(node.key) node.lock_ref += 1 node = node.parent return delta @@ -489,9 +495,9 @@ class RadixCache(BasePrefixCache): delta = 0 while node != self.root_node: if node.lock_ref == 1: - self.evictable_size_ += len(node.value) - self.protected_size_ -= len(node.value) - delta += len(node.value) + self.evictable_size_ += len(node.key) + self.protected_size_ -= len(node.key) + delta += len(node.key) node.lock_ref -= 1 node = node.parent return delta @@ -589,7 +595,7 @@ class RadixCache(BasePrefixCache): new_node.key = key new_node.value = value node.children[child_key] = new_node - self.evictable_size_ += len(value) + self.evictable_size_ += len(key) self._record_store_event(new_node) return total_prefix_length