From 1d7f7835015525964db074f7a94f32929b8c24b4 Mon Sep 17 00:00:00 2001
From: cctry
Date: Tue, 14 Oct 2025 17:45:19 -0700
Subject: [PATCH] Refactor kv cache free (#11351)

---
 python/sglang/srt/disaggregation/decode.py    |  4 +-
 python/sglang/srt/managers/schedule_batch.py  | 55 +++----------
 .../sglang/srt/mem_cache/base_prefix_cache.py |  2 +-
 python/sglang/srt/mem_cache/chunk_cache.py    |  2 +-
 python/sglang/srt/mem_cache/radix_cache.py    | 32 +++++++----
 .../sglang/srt/mem_cache/radix_cache_cpp.py   | 33 ++++++-----
 .../storage/lmcache/lmc_radix_cache.py        |  6 +-
 .../sglang/srt/mem_cache/swa_radix_cache.py   | 28 ++++++----
 8 files changed, 72 insertions(+), 90 deletions(-)

diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 9305be298..8f0e1d6b5 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -611,8 +611,8 @@ class DecodeTransferQueue:
                 self.scheduler.stream_output(
                     [decode_req.req], decode_req.req.return_logprob
                 )
-                # unlock the kv cache or it will have memory leak
-                self.tree_cache.cache_finished_req(decode_req.req)
+                # release the pre-allocated kv cache without inserting into the tree, since the transfer failed
+                self.tree_cache.cache_finished_req(decode_req.req, is_insert=False)
                 indices_to_remove.add(i)
                 if self.scheduler.enable_metrics:
                     self.scheduler.metrics_collector.increment_transfer_failed_reqs()
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 720ac2b67..789cd12db 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -64,6 +64,7 @@ from sglang.srt.mem_cache.common import (
     alloc_for_decode,
     alloc_for_extend,
     alloc_token_slots,
+    evict_from_tree_cache,
 )
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
@@ -1406,7 +1407,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             * self.token_to_kv_pool_allocator.page_size
         )

-        self._evict_tree_cache_if_needed(num_tokens)
+        evict_from_tree_cache(self.tree_cache, num_tokens)
         return self._is_available_size_sufficient(num_tokens)

     def retract_decode(self, server_args: ServerArgs):
@@ -1454,6 +1455,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
+            # release memory without inserting into the tree, because we need the space immediately
             self.release_req(idx, len(sorted_indices), server_args)

         if len(retracted_reqs) == 0:
@@ -1478,39 +1480,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

     def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
         req = self.reqs[idx]
-        seq_lens_cpu = self.seq_lens_cpu.numpy()

         if server_args.disaggregation_mode == "decode":
             req.offload_kv_cache(
                 self.req_to_token_pool, self.token_to_kv_pool_allocator
             )

-        if isinstance(self.tree_cache, ChunkCache):
-            # ChunkCache does not have eviction
-            token_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, : seq_lens_cpu[idx]
-            ]
-            self.token_to_kv_pool_allocator.free(token_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-        else:
-            # TODO: apply more fine-grained retraction
-            last_uncached_pos = (
-                len(req.prefix_indices) // server_args.page_size
-            ) * server_args.page_size
-            token_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-            ]
-            self.token_to_kv_pool_allocator.free(token_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-
-            # release the last node
-            if self.is_hybrid:
-                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
-            else:
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-            # NOTE(lsyin): we should use the newly evictable memory instantly.
-            num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
-            self._evict_tree_cache_if_needed(num_tokens)
+        # TODO (csy): for preempted requests, we may want to insert into the tree
+        self.tree_cache.cache_finished_req(req, is_insert=False)
+        # NOTE(lsyin): we should use the newly evictable memory instantly.
+        num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
+        evict_from_tree_cache(self.tree_cache, num_tokens)

         req.reset_for_retract()
@@ -1808,24 +1787,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             enable_overlap=self.enable_overlap,
         )

-    def _evict_tree_cache_if_needed(self, num_tokens: int):
-        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
-            return
-
-        if self.is_hybrid:
-            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
-            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-
-            if full_available_size < num_tokens or swa_available_size < num_tokens:
-                if self.tree_cache is not None:
-                    full_num_tokens = max(0, num_tokens - full_available_size)
-                    swa_num_tokens = max(0, num_tokens - swa_available_size)
-                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
-        else:
-            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
-                if self.tree_cache is not None:
-                    self.tree_cache.evict(num_tokens)
-
     def _is_available_size_sufficient(self, num_tokens: int) -> bool:
         if self.is_hybrid:
             return (
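Note: evict_from_tree_cache is imported from sglang.srt.mem_cache.common, whose definition is not part of this diff. It presumably hoists the removed ScheduleBatch._evict_tree_cache_if_needed into a free function so the eviction policy lives next to the other mem-cache helpers. A minimal sketch under that assumption; the hasattr probe for the hybrid allocator is illustrative, not the actual implementation:

    from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache

    def evict_from_tree_cache(tree_cache, num_tokens: int) -> None:
        # Chunk caches do not support eviction; nothing to do.
        if tree_cache is None or isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
            return
        allocator = tree_cache.token_to_kv_pool_allocator
        if hasattr(allocator, "full_available_size"):  # hybrid full + SWA pools
            full_free = allocator.full_available_size()
            swa_free = allocator.swa_available_size()
            if full_free < num_tokens or swa_free < num_tokens:
                # Evict only the shortfall from each pool.
                tree_cache.evict(
                    max(0, num_tokens - full_free), max(0, num_tokens - swa_free)
                )
        elif allocator.available_size() < num_tokens:
            tree_cache.evict(num_tokens)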
diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py
index 7c5c7246e..34df99689 100644
--- a/python/sglang/srt/mem_cache/base_prefix_cache.py
+++ b/python/sglang/srt/mem_cache/base_prefix_cache.py
@@ -40,7 +40,7 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def cache_finished_req(self, req: Req, **kwargs):
+    def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs):
         pass

     @abstractmethod
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index 34cd2083f..bb308f077 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -49,7 +49,7 @@ class ChunkCache(BasePrefixCache):
             last_host_node=None,
         )

-    def cache_finished_req(self, req: Req, insert: bool = True):
+    def cache_finished_req(self, req: Req, is_insert: bool = True):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx,
             # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
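With is_insert in the BasePrefixCache contract (and ChunkCache's old insert keyword renamed to match), every call site reduces to one of two shapes, as the decode.py and schedule_batch.py hunks above show. Illustrative usage, not part of this diff:

    # Request finished normally: free its slots and insert the KV into the prefix tree.
    tree_cache.cache_finished_req(req)

    # Transfer failure or retraction: release the KV without caching it.
    tree_cache.cache_finished_req(req, is_insert=False)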
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 2ffc088df..05404dc2b 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -330,18 +330,18 @@ class RadixCache(BasePrefixCache):

         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(self, req: Req):
+    def cache_finished_req(self, req: Req, is_insert: bool = True):
         """Cache request when it finishes."""
+        all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
+                req.req_pool_idx, :all_token_len
             ]
             self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return

-        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
-        all_token_len = len(token_ids)
+        token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
         # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
         # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
         actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
@@ -354,12 +354,9 @@
             page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
                 dtype=torch.int64, copy=True
             )
-            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
-            if self.is_eagle:
-                self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])

         page_aligned_token_len = (
             page_aligned_len + 1 if self.is_eagle else page_aligned_len
@@ -372,11 +369,22 @@
             old_prefix_len -= 1

         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(
-            RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
-            page_aligned_kv_indices,
-        )
-        self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
+        if is_insert:
+            new_prefix_len = self.insert(
+                RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
+                page_aligned_kv_indices,
+            )
+            # Free the duplicates that were already in the tree
+            self.token_to_kv_pool_allocator.free(
+                kv_indices[old_prefix_len:new_prefix_len]
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(
+                kv_indices[old_prefix_len:page_aligned_len]
+            )
+
+        # free the unaligned tail
+        self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
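The free ranges in the hunk above can be checked with standalone arithmetic. A worked example with made-up numbers (page_size 4, an 11-token request with a 6-token cached prefix, is_eagle False):

    page_size = 4
    all_token_len = 11             # len(origin_input_ids) + max(len(output_ids) - 1, 0)
    actual_kv_len = all_token_len  # no EAGLE bigram shift in this example
    prefix_len = 6                 # len(req.prefix_indices)

    page_aligned_len = actual_kv_len // page_size * page_size  # 8
    old_prefix_len = prefix_len // page_size * page_size       # 4

    # is_insert=True : free kv_indices[old_prefix_len:new_prefix_len], the span
    #                  insert() reports as already present in the tree.
    # is_insert=False: free kv_indices[4:8], everything past the cached prefix.
    # Both branches then free the unaligned tail kv_indices[8:11], which the old
    # code freed inside the earlier if/else and which is now a single hoisted call.
    assert page_aligned_len == 8 and old_prefix_len == 4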
diff --git a/python/sglang/srt/mem_cache/radix_cache_cpp.py b/python/sglang/srt/mem_cache/radix_cache_cpp.py
index a16b989fb..7994e372b 100644
--- a/python/sglang/srt/mem_cache/radix_cache_cpp.py
+++ b/python/sglang/srt/mem_cache/radix_cache_cpp.py
@@ -151,32 +151,37 @@ class RadixCacheCpp(BasePrefixCache):
     def total_size(self):
         return self.tree.total_size()

-    def cache_finished_req(self, req: Req):
+    def cache_finished_req(self, req: Req, is_insert: bool = True):
         """Cache request when it finishes."""
         assert req.req_pool_idx is not None
-        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+        token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
         overall_len = len(token_ids)  # prefill + decode
         kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]

         # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
         # it will automatically align them, but length of them should be equal
         old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
-        new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
+        page_aligned_overall_len = overall_len // self.page_size * self.page_size

-        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
-        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
-
-        # KVCache between old & new is newly generated, but already exists in the pool
-        # we need to free this newly generated kv indices
-        if old_prefix_len < new_prefix_len:
-            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+        if is_insert:
+            new_prefix_len = self._insert(
+                RadixKey(token_ids, req.extra_key), kv_indices
+            )
+            # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+            assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+            # Free duplicates that were already in the pool
+            if old_prefix_len < new_prefix_len:
+                self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+        else:
+            self.token_to_kv_pool.free(
+                kv_indices[old_prefix_len:page_aligned_overall_len]
+            )

         # need to free the unaligned part, since it cannot be inserted into the radix tree
-        if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
-            (unaligned_len := overall_len % self.page_size) > 0
-        ):
+        if page_aligned_overall_len < overall_len:
             # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
-            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
+            self.token_to_kv_pool.free(kv_indices[page_aligned_overall_len:])

         # Remove req slot release the cache lock
         self.dec_lock_ref(req.last_node)
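The rewritten tail condition is equivalent to the old walrus-operator form, including the page_size == 1 case where no unaligned tail can exist. A quick brute-force check:

    for page_size in (1, 2, 4, 16):
        for overall_len in range(64):
            page_aligned_overall_len = overall_len // page_size * page_size
            old_condition = page_size != 1 and (overall_len % page_size) > 0
            assert (page_aligned_overall_len < overall_len) == old_condition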
diff --git a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py
index 36061ac14..bf31cbb38 100644
--- a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py
+++ b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py
@@ -217,10 +217,12 @@ class LMCRadixCache(RadixCache):

     return base_res

-    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
+    def cache_finished_req(self, req: "Req", is_insert: bool = True) -> None:  # type: ignore[override]
         """On request completion, insert device KV into radix and store to LMCache."""
-        super().cache_finished_req(req)
+        super().cache_finished_req(req, is_insert=is_insert)

+        if not is_insert:
+            return
         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py
index 928b207d8..566c94784 100644
--- a/python/sglang/srt/mem_cache/swa_radix_cache.py
+++ b/python/sglang/srt/mem_cache/swa_radix_cache.py
@@ -427,19 +427,18 @@ class SWARadixCache(BasePrefixCache):

         return self._insert_helper(self.root_node, key, value, prev_prefix_len)

-    def cache_finished_req(self, req: Req) -> None:
+    def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
         """Cache request when it finishes."""
+        all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx,
-                : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
+                req.req_pool_idx, :all_token_len
             ]
             self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return

-        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
-        all_token_len = len(token_ids)
+        token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
         # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
         # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
         actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
@@ -452,7 +451,6 @@
             page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
                 dtype=torch.int64, copy=True
             )
-            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
@@ -472,11 +470,19 @@
         # Radix Cache takes one ref in memory pool
         # insert the token_ids and kv_indices into the radix tree
         # Note: the insert function already frees the overlapped kv_indices
-        new_prefix_len = self.insert(
-            RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
-            page_aligned_kv_indices,
-            old_prefix_len,
-        )
+        if is_insert:
+            new_prefix_len = self.insert(
+                RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
+                page_aligned_kv_indices,
+                old_prefix_len,
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(
+                kv_indices[old_prefix_len:page_aligned_len]
+            )
+
+        # free the unaligned tail
+        self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
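A closing note on the hoisted all_token_len shared by RadixCache and SWARadixCache: with max(len(output_ids) - 1, 0), the slice still drops the last generated token in the normal case, but no longer drops the last input token when output_ids is empty, the decode-server case called out in chunk_cache.py. A standalone illustration with made-up ids:

    origin_input_ids = [1, 2, 3, 4]
    for output_ids in ([], [5], [5, 6, 7]):
        all_token_len = len(origin_input_ids) + max(len(output_ids) - 1, 0)
        token_ids = (origin_input_ids + output_ids)[:all_token_len]
        if output_ids:
            # Matches the old (origin + output)[:-1] slicing.
            assert token_ids == (origin_input_ids + output_ids)[:-1]
        else:
            # The old [:-1] form would have dropped the last input token here.
            assert token_ids == origin_input_ids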