[Feature] Support mamba radix cache v0 (#11214)
Co-authored-by: hanming-lu <hanming@x.ai> Co-authored-by: hzh0425 <hzh0425@apache.org> Co-authored-by: thalahors <ericalcaide1@gmail.com>
This commit is contained in:
@@ -191,6 +191,9 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||
# Detect straggler ranks in model loading
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
|
||||
# The ratio of mamba cache pool size to max_running_requests; it is safe when larger than 2 (yizhang2077)
|
||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -382,26 +385,10 @@ class ModelRunner:
|
||||
if architectures and not any("Llama4" in arch for arch in architectures):
|
||||
self.is_hybrid = self.model_config.is_hybrid = True
|
||||
|
||||
if config := self.mambaish_config:
|
||||
if config := self.mamba2_config:
|
||||
class_name = config.__class__.__name__
|
||||
logger.warning(f"{class_name} model detected, disable radix cache")
|
||||
self.server_args.disable_radix_cache = True
|
||||
if self.server_args.max_mamba_cache_size is None:
|
||||
if self.server_args.max_running_requests is not None:
|
||||
self.server_args.max_mamba_cache_size = (
|
||||
self.server_args.max_running_requests
|
||||
)
|
||||
else:
|
||||
self.server_args.max_mamba_cache_size = 512
|
||||
if self.hybrid_gdn_config is not None:
|
||||
self.server_args.max_mamba_cache_size = (
|
||||
self.server_args.max_mamba_cache_size
|
||||
// (
|
||||
self.server_args.dp_size
|
||||
if self.server_args.enable_dp_attention
|
||||
else 1
|
||||
)
|
||||
)
|
||||
|
||||
# For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
|
||||
# models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
|
||||
@@ -1330,15 +1317,60 @@ class ModelRunner:
|
||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||
1 - self.mem_fraction_static
|
||||
)
|
||||
if config := self.mambaish_config:
|
||||
rest_memory -= (
|
||||
self.server_args.max_mamba_cache_size
|
||||
* config.mamba2_cache_params.mamba_cache_per_req
|
||||
/ (1 << 30)
|
||||
)
|
||||
if self.mambaish_config is not None:
|
||||
rest_memory = self.handle_max_mamba_cache(rest_memory)
|
||||
max_num_token = int(rest_memory * (1 << 30) // cell_size)
|
||||
return max_num_token
|
||||
|
||||
def handle_max_mamba_cache(self, total_rest_memory):
    """Split the remaining GPU memory budget between the mamba state cache
    and the full KV cache.

    Decides ``server_args.max_mamba_cache_size`` (the number of requests
    whose mamba state can be cached) and returns how much of the budget is
    left for the full KV cache.

    Args:
        total_rest_memory: Remaining memory budget in GiB to be shared by
            the mamba state cache and the full KV cache.

    Returns:
        The memory in GiB still available for the full KV cache after
        reserving the mamba state memory.
    """
    config = self.mambaish_config
    server_args = self.server_args
    assert config is not None

    # Each cached request additionally holds per-draft-token state when
    # speculative decoding is enabled.
    speculative_ratio = (
        0
        if server_args.speculative_num_draft_tokens is None
        else server_args.speculative_num_draft_tokens
    )
    if (
        server_args.disable_radix_cache
        or config.mamba2_cache_params.mamba_cache_per_req == 0
    ):
        # With radix cache disabled, set max_mamba_cache_size based on
        # max_running_requests (default 512 when that is unset too).
        if server_args.max_mamba_cache_size is None:
            if server_args.max_running_requests is not None:
                server_args.max_mamba_cache_size = server_args.max_running_requests
            else:
                server_args.max_mamba_cache_size = 512
    else:
        # Allocate the memory based on the ratio between mamba state memory
        # vs. full kv cache memory; solve the equations:
        # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
        # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
        mamba_state_memory_raw = (
            total_rest_memory
            * server_args.mamba_full_memory_ratio
            / (1 + server_args.mamba_full_memory_ratio)
        )
        # Calculate max_mamba_cache_size from the total mamba memory budget.
        server_args.max_mamba_cache_size = int(
            (mamba_state_memory_raw * (1 << 30))
            // config.mamba2_cache_params.mamba_cache_per_req
            // (1 + speculative_ratio)
        )

    if self.hybrid_gdn_config is not None:
        # For hybrid GDN models, divide the budget across DP ranks when DP
        # attention is enabled.
        server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
            server_args.dp_size if server_args.enable_dp_attention else 1
        )
    # Actual memory (GiB) consumed by the chosen mamba cache size.
    mamba_state_memory = (
        server_args.max_mamba_cache_size
        * config.mamba2_cache_params.mamba_cache_per_req
        * (1 + speculative_ratio)
        / (1 << 30)
    )
    return total_rest_memory - mamba_state_memory
|
||||
|
||||
@property
|
||||
def hybrid_gdn_config(self):
|
||||
config = self.model_config.hf_config
|
||||
@@ -1511,8 +1543,16 @@ class ModelRunner:
|
||||
),
|
||||
4096,
|
||||
)
|
||||
|
||||
if self.mambaish_config is not None:
|
||||
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
|
||||
ratio = (
|
||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
|
||||
if not self.server_args.disable_radix_cache
|
||||
else 1
|
||||
)
|
||||
max_num_reqs = min(
|
||||
max_num_reqs, self.server_args.max_mamba_cache_size // ratio
|
||||
)
|
||||
|
||||
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
|
||||
if self.is_draft_worker:
|
||||
@@ -1595,6 +1635,7 @@ class ModelRunner:
|
||||
elif config := self.mambaish_config:
|
||||
self.req_to_token_pool = HybridReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
mamba_size=self.server_args.max_mamba_cache_size,
|
||||
max_context_len=self.model_config.context_len
|
||||
+ extra_max_context_len,
|
||||
device=self.device,
|
||||
|
||||
Reference in New Issue
Block a user