[Feature] Support mamba radix cache v0 (#11214)
Co-authored-by: hanming-lu <hanming@x.ai> Co-authored-by: hzh0425 <hzh0425@apache.org> Co-authored-by: thalahors <ericalcaide1@gmail.com>
This commit is contained in:
@@ -146,6 +146,7 @@ from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
@@ -470,6 +471,10 @@ class Scheduler(
|
||||
|
||||
# Hybrid memory pool
|
||||
self.is_hybrid = self.tp_worker.is_hybrid
|
||||
self.is_hybrid_gdn = (
|
||||
self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
|
||||
)
|
||||
|
||||
if self.is_hybrid:
|
||||
self.sliding_window_size = self.tp_worker.sliding_window_size
|
||||
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
|
||||
@@ -816,6 +821,16 @@ class Scheduler(
|
||||
disable=server_args.disable_radix_cache,
|
||||
is_eagle=self.spec_algorithm.is_eagle(),
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
assert (
|
||||
self.server_args.disaggregation_mode == "null"
|
||||
), "Hybrid GDN mode does not support disaggregation yet"
|
||||
self.tree_cache = MambaRadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
elif server_args.enable_lmcache:
|
||||
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
|
||||
LMCRadixCache,
|
||||
@@ -1689,6 +1704,25 @@ class Scheduler(
|
||||
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
|
||||
)
|
||||
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
|
||||
(
|
||||
full_num_used,
|
||||
mamba_num_used,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
mamba_available_size,
|
||||
mamba_evictable_size,
|
||||
) = self._get_mamba_token_info()
|
||||
memory_leak = (
|
||||
full_num_used != self.tree_cache.full_protected_size()
|
||||
or mamba_num_used != self.tree_cache.mamba_protected_size()
|
||||
)
|
||||
token_msg = (
|
||||
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
|
||||
)
|
||||
else:
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
@@ -1739,6 +1773,17 @@ class Scheduler(
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
elif self.is_hybrid_gdn:
|
||||
(
|
||||
num_used,
|
||||
_,
|
||||
token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_mamba_token_info()
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
@@ -1766,7 +1811,9 @@ class Scheduler(
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_tree_cache(self):
|
||||
if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
|
||||
if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
|
||||
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
|
||||
):
|
||||
self.tree_cache.sanity_check()
|
||||
|
||||
def _get_token_info(self):
|
||||
@@ -1776,6 +1823,35 @@ class Scheduler(
|
||||
token_usage = num_used / self.max_total_num_tokens
|
||||
return num_used, token_usage, available_size, evictable_size
|
||||
|
||||
def _get_mamba_token_info(self):
|
||||
is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
|
||||
full_available_size = self.token_to_kv_pool_allocator.available_size()
|
||||
full_evictable_size = (
|
||||
self.tree_cache.full_evictable_size() if is_radix_tree else 0
|
||||
)
|
||||
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
|
||||
mamba_evictable_size = (
|
||||
self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
|
||||
)
|
||||
full_num_used = self.token_to_kv_pool_allocator.size - (
|
||||
full_available_size + full_evictable_size
|
||||
)
|
||||
mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
|
||||
mamba_available_size + mamba_evictable_size
|
||||
)
|
||||
full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
|
||||
mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
|
||||
return (
|
||||
full_num_used,
|
||||
mamba_num_used,
|
||||
full_token_usage,
|
||||
mamba_usage,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
mamba_available_size,
|
||||
mamba_evictable_size,
|
||||
)
|
||||
|
||||
def _get_swa_token_info(self):
|
||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
full_evictable_size = self.tree_cache.full_evictable_size()
|
||||
|
||||
Reference in New Issue
Block a user