From dbd9435dc1c00337eb36e098d38c15b88c909681 Mon Sep 17 00:00:00 2001
From: Roger Young <42564206+rogeryoungh@users.noreply.github.com>
Date: Fri, 24 Oct 2025 04:07:43 +0800
Subject: [PATCH] Fix mamba radix cache eviction logic in `alloc_req_slots`
 (#11616)

Signed-off-by: rogeryoungh
---
 python/sglang/srt/managers/schedule_batch.py | 23 +-------------------
 python/sglang/srt/mem_cache/common.py        | 11 +++++++++-
 2 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 93e11424d..be2de0cc7 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -66,7 +66,7 @@ from sglang.srt.mem_cache.common import (
     evict_from_tree_cache,
 )
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
-from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
@@ -1080,27 +1080,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
-        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
-            mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
-            if mamba_available_size < num_reqs:
-                if self.tree_cache is not None and isinstance(
-                    self.tree_cache, MambaRadixCache
-                ):
-                    mamba_num = max(0, num_reqs - mamba_available_size)
-                    self.tree_cache.evict_mamba(mamba_num)
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
-        else:
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
-        if req_pool_indices is None:
-            raise RuntimeError(
-                "alloc_req_slots runs out of memory. "
-                "Please set a smaller number for `--max-running-requests`. "
-                f"{self.req_to_token_pool.available_size()=}, "
-                f"{num_reqs=}, "
-            )
-        return req_pool_indices
-
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py
index 979b697c3..8f96ba0fe 100644
--- a/python/sglang/srt/mem_cache/common.py
+++ b/python/sglang/srt/mem_cache/common.py
@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import support_triton
@@ -292,9 +293,15 @@ def alloc_req_slots(
     req_to_token_pool: ReqToTokenPool,
     num_reqs: int,
     reqs: list[Req] | None,
+    tree_cache: BasePrefixCache | None,
 ) -> list[int]:
     """Allocate request slots from the pool."""
     if isinstance(req_to_token_pool, HybridReqToTokenPool):
+        mamba_available_size = req_to_token_pool.mamba_pool.available_size()
+        if mamba_available_size < num_reqs:
+            if tree_cache is not None and isinstance(tree_cache, MambaRadixCache):
+                mamba_num = max(0, num_reqs - mamba_available_size)
+                tree_cache.evict_mamba(mamba_num)
         req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)
     else:
         req_pool_indices = req_to_token_pool.alloc(num_reqs)
@@ -337,7 +344,9 @@ def alloc_for_extend(
     extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True)
 
     # Allocate req slots
-    req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs)
+    req_pool_indices = alloc_req_slots(
+        batch.req_to_token_pool, bs, batch.reqs, batch.tree_cache
+    )
     req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
     req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)
 