[Feature] Support mamba radix cache v0 (#11214)

Co-authored-by: hanming-lu <hanming@x.ai>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: thalahors <ericalcaide1@gmail.com>
This commit is contained in:
Yi Zhang
2025-10-13 11:57:15 +08:00
committed by GitHub
parent 19ba16aa3d
commit a55cf5304a
10 changed files with 1593 additions and 55 deletions

View File

@@ -65,7 +65,8 @@ from sglang.srt.mem_cache.common import (
alloc_for_extend,
alloc_token_slots,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
@@ -522,6 +523,7 @@ class Req:
# Memory pool info
self.req_pool_idx: Optional[int] = None
self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1)
# Check finish
self.tokenizer = None
@@ -727,7 +729,12 @@ class Req:
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
**(
{"req": self, "cow_mamba": True}
if isinstance(tree_cache, MambaRadixCache)
else {}
),
)
self.last_matched_prefix_len = len(self.prefix_indices)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
@@ -877,6 +884,7 @@ class Req:
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
self.mamba_pool_idx = None
self.already_computed = 0
def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
@@ -1071,6 +1079,27 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self):
return len(self.reqs) == 0
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
if mamba_available_size < num_reqs:
if self.tree_cache is not None and isinstance(
self.tree_cache, MambaRadixCache
):
mamba_num = max(0, num_reqs - mamba_available_size)
self.tree_cache.evict_mamba(mamba_num)
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
else:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"alloc_req_slots runs out of memory. "
"Please set a smaller number for `--max-running-requests`. "
f"{self.req_to_token_pool.available_size()=}, "
f"{num_reqs=}, "
)
return req_pool_indices
def allocate_for_eagle_v2(self):
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool