Files
xc-llm-ascend/vllm_ascend/patch/platform/patch_mamba_config.py
zxr2333 d39d80830c [KVCache]Qwen3.5 supports contiguous tensor hybrid-attn kv-cache (#6887)
### What this PR does / why we need it?
Supports a contiguous-tensor hybrid-attention KV cache for full-attention +
mamba hybrid models, such as Qwen3Next and Qwen3.5.
Due to restrictions of the Ascend operators, all KV tensors, conv
tensors, and SSM tensors must be contiguous. Therefore, this PR generates the
KV cache with the following layout:
```
tensor1: [ (kv_padding), conv           , ... ]
tensor2: [ k           , ssm            , ... ]
tensor3: [ v           , (mamba_padding), ... ]
```
Under this scheme some memory is wasted on padding, but every cache tensor is
guaranteed to be contiguous.
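
To make the alignment concrete, here is a minimal sketch of the arithmetic,
mirroring the computation in the patched `verify_and_update_config` below.
All sizes and names here are illustrative, not taken from any real model
config: the attention block size is rounded up to a kernel-block multiple so
that one block of K exactly covers one SSM state, and the mamba page is then
padded by the conv state.

```python
# Illustrative sketch only; all sizes below are made up.
import math

dtype_size = 2                            # e.g. bf16
num_kv_heads, head_size = 4, 128
k_bytes_per_token = num_kv_heads * head_size * dtype_size  # K cache per token

ssm_state_bytes = 2 * 1024 * 1024         # hypothetical per-layer SSM state
conv_state_bytes = 64 * 1024              # hypothetical per-layer conv state

# choose the attention block size (a multiple of the 128-token kernel block)
# so that one block of K is exactly as large as one SSM state
kernel_block = 128
attn_block_size = kernel_block * math.ceil(
    ssm_state_bytes / (kernel_block * k_bytes_per_token))
assert attn_block_size * k_bytes_per_token == ssm_state_bytes

# a full attention page holds K and V; the mamba page is padded by the conv state
attn_page_bytes = 2 * attn_block_size * k_bytes_per_token
mamba_page_padded = attn_page_bytes + conv_state_bytes
print(attn_block_size, attn_page_bytes, mamba_page_padded)
```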

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
2026-03-09 15:28:40 +08:00


# mypy: ignore-errors
import math

import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size


@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    """
    Ensure that the page size of the attention layers is greater than or
    equal to that of the mamba layers. If not, automatically raise the
    attention block size so that it is. If the attention page size ends up
    strictly greater than the mamba page size, pad the mamba page size so
    that the two are equal.

    Args:
        vllm_config: vLLM Config
    """
    logger = init_logger(__name__)

    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

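    # the attention block size chosen below is kept a multiple of this
    # kernel block granularity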
    kernel_block_size = 128

    # per-token size of the attention K cache, in bytes
    attn_num_kv_heads = model_config.get_num_kv_heads(parallel_config)
    attn_head_size = model_config.get_head_size()
    attn_single_token_k_page_size = (
        attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype))

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )
    # per-layer sizes of the mamba states (SSM and conv), in bytes
    mamba_shapes = model_cls.get_mamba_state_shape_from_config(vllm_config)
    mamba_dtypes = model_cls.get_mamba_state_dtype_from_config(vllm_config)
    mamba_sizes = []
    for shape, dtype in zip(mamba_shapes, mamba_dtypes):
        mamba_sizes.append(math.prod(shape) * get_dtype_size(dtype))

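    # the larger of the two mamba states is taken to be the SSM state,
    # the smaller to be the conv state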
    ssm_block_page_size, conv_block_page_size = max(mamba_sizes), min(mamba_sizes)

    # NOTE(zxr): because of the limits of Ascend hardware, we need to keep
    # all cache tensors contiguous, so we align the page size of an ssm_block
    # with that of a single attn_block
    attn_block_size = kernel_block_size * cdiv(
        ssm_block_page_size, kernel_block_size * attn_single_token_k_page_size)
    assert attn_single_token_k_page_size * attn_block_size == ssm_block_page_size, (
        "Cannot align ssm_page_size and attn_page_size.")

    # override the attention block size if either (a) the user has not set it
    # or (b) the user has set it too small
    if cache_config.block_size is None or cache_config.block_size < attn_block_size:
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens to ensure that "
            "attention page size is >= mamba page size.",
            attn_block_size,
        )

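    # one attention page stores both the K and the V cache for a block of
    # tokens, hence the factor of 2 below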
    # compute new attention page size
    attn_page_size = (cache_config.block_size * 2 * attn_head_size *
                      attn_num_kv_heads * get_dtype_size(kv_cache_dtype))

    # pad mamba page size for conv_blocks
    if (
        cache_config.mamba_page_size_padded is None
        or cache_config.mamba_page_size_padded != attn_page_size + conv_block_page_size
    ):
        cache_config.mamba_page_size_padded = attn_page_size + conv_block_page_size
        mamba_padding_pct = (100 * conv_block_page_size /
                             cache_config.mamba_page_size_padded)
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.",
            mamba_padding_pct,
        )

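    # with prefix caching in "align" mode, mamba blocks follow the attention
    # block size; otherwise a single mamba block covers the whole sequence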
    if (cache_config.enable_prefix_caching
            and cache_config.mamba_cache_mode == "align"):
        cache_config.mamba_block_size = cache_config.block_size
    else:
        cache_config.mamba_block_size = model_config.max_model_len


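# replace vLLM's hybrid attention/mamba config hook with the Ascend-aware
# version defined above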
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config