Files
xc-llm-ascend/vllm_ascend/patch/platform/patch_mamba_config_310.py
pu-zhe e8f7b2e3f1 [Refactor] [310p] Support Mamba Cache and support attn_head_size larger than 128 (#7372)
### What this PR does / why we need it?
1. Mamba Cache Support on 310P: Implemented logic to correctly
initialize and allocate KV cache for Mamba models on the 310P platform,
including handling of state tensors and page size alignment.
2. Increased Attention Head Size Support: Modified the attention backend
to support attn_head_size larger than 128 by dynamically selecting
appropriate kernel block sizes based on hardware limitations (e.g.,
block_size * head_size <= 16384).
3. Refactored KV Cache Allocation: Consolidated and improved the KV
cache allocation mechanism, moving from separate size calculation and
allocation steps to a unified _allocate_kv_cache_tensors method that
handles both Attention and Mamba specific cache structures.
4. Dynamic Mamba Config Patching: Introduced conditional loading of
Mamba configuration patches, specifically using patch_mamba_config_310
for the 310P platform to ensure platform-specific optimizations and
validations.
5. Reserve a reasonable amount of memory for KV cache allocation to avoid
OOM issues with the default gpu_memory_utilization.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Qwen3.5 E2E test
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
2026-03-19 09:16:22 +08:00

105 lines
4.3 KiB
Python

# mypy: ignore-errors
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import lcm
import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    """
    Ensure that page size of attention layers is greater than or
    equal to the mamba layers. If not, automatically set the attention
    block size to ensure that it is. If the attention page size is
    strictly greater than the mamba page size, we pad the mamba page size
    to make them equal.

    Args:
        vllm_config: vLLM Config
    """
    logger = init_logger(__name__)

    # Save the user input before it gets modified by MambaModelConfig
    mamba_block_size = vllm_config.cache_config.mamba_block_size

    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    # Resolve the torch dtype used for the KV cache: "auto" follows the
    # model dtype, otherwise map the user string to a torch dtype.
    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    if model_config.use_mla:
        raise RuntimeError("MLA is not supported on 310P currently.")

    # Attention block sizes are kept as multiples of 128 tokens.
    # NOTE(review): presumably a 310P attention-kernel alignment
    # requirement (see the PR notes about kernel block-size limits such as
    # block_size * head_size <= 16384) — confirm against the kernel docs.
    kernel_block_alignment_size = 128
    # Per-token page size of one full-attention layer, obtained by pricing
    # a spec with block_size=1.
    attn_page_size_1_token = FullAttentionSpec(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype,
    ).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    # block_size is irrelevant for sizing the state tensors, hence -1.
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=-1,
    ).page_size_bytes

    # Model may be marked as is_hybrid
    # but mamba is skipped via config,
    # return directly
    if mamba_page_size == 0:
        return

    if cache_config.mamba_cache_mode == "all":
        # "all" mode: derive the attention block size from the mamba chunk
        # size (user-provided mamba_block_size wins over the model default),
        # rounded so one mamba state maps to a whole number of
        # 128-aligned chunks of attention tokens.
        base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
        # Tokens of attention KV cache needed to cover one mamba state.
        attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
        chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
        attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
        cache_config.mamba_block_size = attn_block_size
    else:
        # Smallest 128-aligned attention block whose page size covers one
        # mamba state.
        attn_block_size = kernel_block_alignment_size * cdiv(
            mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
        )

    # Only grow the block size; a user-set block_size that is already large
    # enough is left untouched.
    if cache_config.block_size is None or cache_config.block_size < attn_block_size:
        cache_config.block_size = attn_block_size
        # NOTE(review): indentation was reconstructed — this log most
        # plausibly belongs inside the branch (it reports the value just
        # assigned), matching upstream vLLM.
        logger.info(
            "Setting attention block size to %d tokens to ensure that attention page size is >= mamba page size.",
            attn_block_size,
        )

    if cache_config.mamba_cache_mode == "align":
        # "align" mode: mamba blocks track the (possibly user-set)
        # attention block size exactly.
        cache_config.mamba_block_size = cache_config.block_size

    # Final attention page size for the chosen block size; by construction
    # it can only be >= the mamba page size.
    attn_page_size = cache_config.block_size * attn_page_size_1_token
    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size:
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.",
            mamba_padding_pct,
        )
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config