[Feature] Support mamba radix cache v0 (#11214)

Co-authored-by: hanming-lu <hanming@x.ai>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: thalahors <ericalcaide1@gmail.com>
This commit is contained in:
Yi Zhang
2025-10-13 11:57:15 +08:00
committed by GitHub
parent 19ba16aa3d
commit a55cf5304a
10 changed files with 1593 additions and 55 deletions

View File

@@ -146,6 +146,7 @@ from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -470,6 +471,10 @@ class Scheduler(
# Hybrid memory pool
self.is_hybrid = self.tp_worker.is_hybrid
self.is_hybrid_gdn = (
self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
)
if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
@@ -816,6 +821,16 @@ class Scheduler(
disable=server_args.disable_radix_cache,
is_eagle=self.spec_algorithm.is_eagle(),
)
elif self.is_hybrid_gdn:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid GDN mode does not support disaggregation yet"
self.tree_cache = MambaRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
LMCRadixCache,
@@ -1689,6 +1704,25 @@ class Scheduler(
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
)
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
(
full_num_used,
mamba_num_used,
_,
_,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,
) = self._get_mamba_token_info()
memory_leak = (
full_num_used != self.tree_cache.full_protected_size()
or mamba_num_used != self.tree_cache.mamba_protected_size()
)
token_msg = (
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
)
else:
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()
@@ -1739,6 +1773,17 @@ class Scheduler(
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
elif self.is_hybrid_gdn:
(
num_used,
_,
token_usage,
_,
_,
_,
_,
_,
) = self._get_mamba_token_info()
else:
num_used, token_usage, _, _ = self._get_token_info()
num_running_reqs = len(self.running_batch.reqs)
@@ -1766,7 +1811,9 @@ class Scheduler(
self._publish_kv_events()
def check_tree_cache(self):
if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
):
self.tree_cache.sanity_check()
def _get_token_info(self):
@@ -1776,6 +1823,35 @@ class Scheduler(
token_usage = num_used / self.max_total_num_tokens
return num_used, token_usage, available_size, evictable_size
def _get_mamba_token_info(self):
is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
full_available_size = self.token_to_kv_pool_allocator.available_size()
full_evictable_size = (
self.tree_cache.full_evictable_size() if is_radix_tree else 0
)
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
mamba_evictable_size = (
self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
)
full_num_used = self.token_to_kv_pool_allocator.size - (
full_available_size + full_evictable_size
)
mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
mamba_available_size + mamba_evictable_size
)
full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
return (
full_num_used,
mamba_num_used,
full_token_usage,
mamba_usage,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,
)
def _get_swa_token_info(self):
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
full_evictable_size = self.tree_cache.full_evictable_size()