[Feature] Support mamba radix cache v0 (#11214)

Co-authored-by: hanming-lu <hanming@x.ai>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: thalahors <ericalcaide1@gmail.com>
This commit is contained in:
Yi Zhang
2025-10-13 11:57:15 +08:00
committed by GitHub
parent 19ba16aa3d
commit a55cf5304a
10 changed files with 1593 additions and 55 deletions

View File

@@ -191,6 +191,9 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
logger = logging.getLogger(__name__)
@@ -382,26 +385,10 @@ class ModelRunner:
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid = self.model_config.is_hybrid = True
if config := self.mambaish_config:
if config := self.mamba2_config:
class_name = config.__class__.__name__
logger.warning(f"{class_name} model detected, disable radix cache")
self.server_args.disable_radix_cache = True
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_running_requests
)
else:
self.server_args.max_mamba_cache_size = 512
if self.hybrid_gdn_config is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_mamba_cache_size
// (
self.server_args.dp_size
if self.server_args.enable_dp_attention
else 1
)
)
# For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
# models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
@@ -1330,15 +1317,60 @@ class ModelRunner:
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
if config := self.mambaish_config:
rest_memory -= (
self.server_args.max_mamba_cache_size
* config.mamba2_cache_params.mamba_cache_per_req
/ (1 << 30)
)
if self.mambaish_config is not None:
rest_memory = self.handle_max_mamba_cache(rest_memory)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
def handle_max_mamba_cache(self, total_rest_memory):
    """Budget GPU memory between the mamba state cache and the full KV cache.

    Decides ``server_args.max_mamba_cache_size`` (the number of per-request
    mamba cache slots) and returns how much memory is left for the full KV
    cache.

    Args:
        total_rest_memory: Remaining GPU memory in GiB available for caches.

    Returns:
        ``total_rest_memory`` minus the memory reserved for the mamba state
        cache, in GiB.
    """
    config = self.mambaish_config
    server_args = self.server_args
    assert config is not None

    # Speculative decoding keeps extra draft-token states per request, so it
    # multiplies the per-request mamba cache footprint.
    # (fixed typo: was `speculativa_ratio`)
    speculative_ratio = (
        0
        if server_args.speculative_num_draft_tokens is None
        else server_args.speculative_num_draft_tokens
    )

    if (
        server_args.disable_radix_cache
        or config.mamba2_cache_params.mamba_cache_per_req == 0
    ):
        # With radix cache disabled, size the mamba cache from
        # max_running_requests (one slot per in-flight request), falling
        # back to a default of 512 slots.
        if server_args.max_mamba_cache_size is None:
            if server_args.max_running_requests is not None:
                server_args.max_mamba_cache_size = server_args.max_running_requests
            else:
                server_args.max_mamba_cache_size = 512
    else:
        # Allocate memory based on the ratio between mamba state memory and
        # full KV cache memory by solving:
        #   1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
        #   2. mamba_state_memory / full_kv_cache_memory
        #        == server_args.mamba_full_memory_ratio
        mamba_state_memory_raw = (
            total_rest_memory
            * server_args.mamba_full_memory_ratio
            / (1 + server_args.mamba_full_memory_ratio)
        )
        # Convert the GiB budget into a whole number of cache slots.
        server_args.max_mamba_cache_size = int(
            (mamba_state_memory_raw * (1 << 30))
            // config.mamba2_cache_params.mamba_cache_per_req
            // (1 + speculative_ratio)
        )

    if self.hybrid_gdn_config is not None:
        # Hybrid GDN shards the mamba cache across DP ranks when DP
        # attention is enabled.
        server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
            server_args.dp_size if server_args.enable_dp_attention else 1
        )

    # Recompute the memory actually consumed by the final (rounded/divided)
    # slot count so the caller receives an accurate remainder.
    mamba_state_memory = (
        server_args.max_mamba_cache_size
        * config.mamba2_cache_params.mamba_cache_per_req
        * (1 + speculative_ratio)
        / (1 << 30)
    )
    return total_rest_memory - mamba_state_memory
@property
def hybrid_gdn_config(self):
config = self.model_config.hf_config
@@ -1511,8 +1543,16 @@ class ModelRunner:
),
4096,
)
if self.mambaish_config is not None:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
ratio = (
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
if not self.server_args.disable_radix_cache
else 1
)
max_num_reqs = min(
max_num_reqs, self.server_args.max_mamba_cache_size // ratio
)
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
if self.is_draft_worker:
@@ -1595,6 +1635,7 @@ class ModelRunner:
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
mamba_size=self.server_args.max_mamba_cache_size,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,