From d39d80830c41d574a4967018f1449e61574b078e Mon Sep 17 00:00:00 2001
From: zxr2333 <64738772+nwpu-zxr@users.noreply.github.com>
Date: Mon, 9 Mar 2026 15:28:40 +0800
Subject: [PATCH] [KVCache] Qwen3.5 supports contiguous tensor hybrid-attn
 kv-cache (#6887)

### What this PR does / why we need it?
Supports a contiguous-tensor hybrid-attn KV cache for full-attention/Mamba
hybrid models such as Qwen3Next and Qwen3.5.

Due to the restrictions of Ascend operators, all KV tensors, conv tensors,
and SSM tensors must be contiguous. This PR therefore generates the KV
cache with the following layout:

    tensor1: [(kv_padding), conv          , ...]
    tensor2: [k           , ssm           , ...]
    tensor3: [v           , (mamba_padding), ...]

Although this scheme wastes some memory, it guarantees that all cache
tensors are contiguous.
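To make the layout concrete, below is a minimal PyTorch sketch of one plausible reading of the diagram: a single flat byte buffer split into three contiguous segments, with the mamba states overlaying the regions that attention does not use. Every name and size here is an illustrative assumption, not code from this patch.

```python
import torch

num_blocks = 4
attn_k_page = 1024            # bytes of k (or v) per block -- assumed
ssm_page = attn_k_page        # the ssm page is aligned to the k page size
conv_page = 256               # bytes of conv state per block -- assumed
padded_page = conv_page + 2 * attn_k_page  # mamba_page_size_padded

flat = torch.zeros(num_blocks * padded_page, dtype=torch.int8)

# tensor1: [(kv_padding), conv, ...] -- conv states; padding from the attn view
conv_seg = flat[: num_blocks * conv_page]
# tensor2: [k, ssm, ...] -- k cache for attn layers, ssm state for mamba layers
k_seg = flat[num_blocks * conv_page : num_blocks * (conv_page + attn_k_page)]
ssm_seg = k_seg  # mamba layers overlay the same bytes as the k cache
# tensor3: [v, (mamba_padding), ...] -- v cache; unused from the mamba view
v_seg = flat[num_blocks * (conv_page + attn_k_page) :]

assert all(t.is_contiguous() for t in (conv_seg, k_seg, ssm_seg, v_seg))
```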
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.

- vLLM version: v0.16.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

---------

Signed-off-by: nwpu-zxr
---
 .../patch/platform/patch_mamba_config.py | 67 ++++++++++---------
 vllm_ascend/worker/model_runner_v1.py    | 63 ++++++++++++++---
 2 files changed, 89 insertions(+), 41 deletions(-)

diff --git a/vllm_ascend/patch/platform/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_mamba_config.py
index beb7cec7..d0a0f81e 100644
--- a/vllm_ascend/patch/platform/patch_mamba_config.py
+++ b/vllm_ascend/patch/platform/patch_mamba_config.py
@@ -1,11 +1,12 @@
 # mypy: ignore-errors
+import math
+
 import vllm.model_executor.models.config
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.config import MambaModelConfig
 from vllm.utils.math_utils import cdiv
-from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
-from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
 
 
 @classmethod
@@ -33,33 +34,32 @@ def verify_and_update_config(cls, vllm_config) -> None:
     else:
         kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
-    # get attention page size (for 1 token)
-    attn_page_size_1_token = FullAttentionSpec(
-        block_size=1,
-        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
-        head_size=model_config.get_head_size(),
-        dtype=kv_cache_dtype,
-    ).page_size_bytes
+    kernel_block_size = 128
+    # get attention block size
+    attn_num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+    attn_head_size = model_config.get_head_size()
+    attn_single_token_k_page_size = attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype)
 
     model_cls, _ = ModelRegistry.resolve_model_cls(
         model_config.architecture,
         model_config=model_config,
     )
 
-    # get mamba page size
-    mamba_page_size = MambaSpec(
-        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
-        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
-        block_size=model_config.max_model_len,
-    ).page_size_bytes
+    # get mamba block size
+    mamba_shapes = model_cls.get_mamba_state_shape_from_config(vllm_config)
+    mamba_dtypes = model_cls.get_mamba_state_dtype_from_config(vllm_config)
+    mamba_sizes = []
+    for shape, dtype in zip(mamba_shapes, mamba_dtypes):
+        mamba_sizes.append(math.prod(shape) * get_dtype_size(dtype))
+    ssm_block_page_size, conv_block_page_size = max(mamba_sizes), min(mamba_sizes)
 
-    block_alignment_bytes = 128
-
-    # some attention backends (e.g. FA) only support setting
-    # block size to multiple of 16, so let's suggest a value
-    # that would work (note: FA is currently not compatible
-    # with mamba layers, use FlashInfer instead).
-    attn_block_size = block_alignment_bytes * cdiv(mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
+    # NOTE(zxr): because of the limits of Ascend hardware, we need to keep
+    # all cache tensors contiguous, so we align the page sizes of an ssm
+    # block and a single attn block
+    attn_block_size = kernel_block_size * cdiv(ssm_block_page_size, kernel_block_size * attn_single_token_k_page_size)
+    assert attn_single_token_k_page_size * attn_block_size == ssm_block_page_size, (
+        "Cannot align ssm_page_size and attn_page_size."
+    )
 
     # override attention block size if either (a) the
     # user has not set it or (b) the user has set it
@@ -72,24 +72,25 @@ def verify_and_update_config(cls, vllm_config) -> None:
         )
 
     # compute new attention page size
-    attn_page_size = cache_config.block_size * attn_page_size_1_token
+    attn_page_size = cache_config.block_size * 2 * attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype)
 
-    assert attn_page_size >= mamba_page_size
-
-    if attn_page_size == mamba_page_size:
-        # don't need to pad mamba page size
-        return
-
-    # pad mamba page size to exactly match attention
-    if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size:
-        cache_config.mamba_page_size_padded = attn_page_size
-        mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
+    # pad mamba page size for conv_blocks
+    if (
+        cache_config.mamba_page_size_padded is None
+        or cache_config.mamba_page_size_padded != attn_page_size + conv_block_page_size
+    ):
+        cache_config.mamba_page_size_padded = attn_page_size + conv_block_page_size
+        mamba_padding_pct = 100 * conv_block_page_size / cache_config.mamba_page_size_padded
         logger.info(
             "Padding mamba page size by %.2f%% to ensure "
             "that mamba page size and attention page size are "
             "exactly equal.",
             mamba_padding_pct,
         )
+    if cache_config.enable_prefix_caching and cache_config.mamba_cache_mode == "align":
+        cache_config.mamba_block_size = cache_config.block_size
+    else:
+        cache_config.mamba_block_size = model_config.max_model_len
 
 
 vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
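For intuition, here is a worked example of the block-size alignment performed in `patch_mamba_config.py` above. The sizes are assumed, illustrative numbers, not values from any real model; `cdiv` is redefined locally to mirror the `cdiv` imported from `vllm.utils.math_utils`:

```python
from math import ceil

def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring vllm.utils.math_utils.cdiv."""
    return ceil(a / b)

kernel_block_size = 128                # Ascend kernels want 128-token multiples
attn_single_token_k_page_size = 2048   # assumed: num_kv_heads * head_size * dtype bytes
ssm_block_page_size = 786_432          # assumed ssm state bytes per mamba block

# smallest 128-token-aligned attn block whose k page covers the ssm page
attn_block_size = kernel_block_size * cdiv(
    ssm_block_page_size, kernel_block_size * attn_single_token_k_page_size
)
assert attn_block_size == 384  # 786432 / (128 * 2048) = 3 -> 3 * 128 tokens
# the patch then asserts the two page sizes match exactly:
assert attn_single_token_k_page_size * attn_block_size == ssm_block_page_size
```

The final assertion only holds when the ssm page is an exact multiple of the kernel-aligned k page, which is what the in-patch assert enforces at startup.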
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 47c7c5a6..86cd598c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -45,6 +45,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import LazyLoader
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.mem_utils import DeviceMemoryProfiler
+from vllm.utils.torch_utils import get_dtype_size
 from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -2546,12 +2547,28 @@ class NPUModelRunner(GPUModelRunner):
         kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {}
         # prefill disaggregation need the addr of cache tensor be aligned with 2M
         alignment = 2 * 1024 * 1024
+        layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
+        for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
+            for layer_name in group_kv_cache_spec.layer_names:
+                layer_kv_cache_spec[layer_name] = group_kv_cache_spec.kv_cache_spec
+        # If some tensors are shared by linear layers and attention layers, the
+        # same tensor format must be maintained even for tensors whose layers
+        # are all linear or all attention, for example, the mtp layer.
+        self.hybrid_with_attn_and_mamba = False
         for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-            # TODO: REFACTOR ME to sharing hybrid cache
+            use_mamba, use_attn = False, False
+            for layer_name in kv_cache_tensor.shared_by:
+                if isinstance(layer_kv_cache_spec[layer_name], MambaSpec):
+                    use_mamba = True
+                if isinstance(layer_kv_cache_spec[layer_name], AttentionSpec):
+                    use_attn = True
+            self.hybrid_with_attn_and_mamba = self.hybrid_with_attn_and_mamba or (use_mamba and use_attn)
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
-                if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors:
-                    # for mamba linear attention
+                if (
+                    "linear_attn" in layer_name or self.hybrid_with_attn_and_mamba
+                ) and layer_name not in kv_cache_raw_tensors:
+                    # for mamba linear attention or an attn-linear hybrid
                     if self.vllm_config.kv_transfer_config is None:
                         tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=self.device)
                     else:
@@ -2560,10 +2577,9 @@ class NPUModelRunner(GPUModelRunner):
                         tensor = self._align_memory(tensor, alignment)[: kv_cache_tensor.size]
                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        # shared the kvcache between the self_attn specs in the same group
-                        if "linear_attn" in layer_name_inner:
-                            kv_cache_raw_tensors[layer_name_inner] = tensor
-                elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors:
+                        # share the kvcache for all shared layers
+                        kv_cache_raw_tensors[layer_name_inner] = tensor
+                elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors and not use_mamba:
                     # NOTE: We need to init k cache tensor (nope cache tensor in mla) and
                     # v cache tensor (rope cache tensor in mla) separately to support prefill disaggregation,
                     # as it only support the 0-dim of kv_cache is `num_blocks`.
@@ -2616,7 +2632,7 @@ class NPUModelRunner(GPUModelRunner):
                         dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size]
                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        # shared the kvcache between the self_attn specs in the same group
+                        # share the attn kvcache for all shared layers
                         if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
                             kv_cache_raw_tensors[layer_name_inner] = (
                                 (k_tensor, v_tensor)
@@ -2651,6 +2667,10 @@ class NPUModelRunner(GPUModelRunner):
             corresponding memory buffer for KV cache.
         """
         kv_caches: dict[str, torch.Tensor] = {}
+        layer_kv_cache_spec = {}
+        for group in kv_cache_config.kv_cache_groups:
+            for layer_name in group.layer_names:
+                layer_kv_cache_spec[layer_name] = group.kv_cache_spec
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
@@ -2668,6 +2688,11 @@ class NPUModelRunner(GPUModelRunner):
                 ]
                 assert raw_dsa_k_tensor is not None
                 sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
+            elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba:
+                # Currently, we ensure that the same kvcache format is used even if
+                # there is no shared layer, e.g. the full-attention mtp layer of Qwen3.5.
+                raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[layer_name], kv_cache_raw_tensors[layer_name]
+                sum_page_size_bytes = raw_k_tensor.numel()
             else:
                 raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[  # type: ignore
                     layer_name
@@ -2697,6 +2722,14 @@ class NPUModelRunner(GPUModelRunner):
                         kv_cache_spec.num_kv_heads,
                         kv_cache_spec.head_size,
                     )
+                    if self.hybrid_with_attn_and_mamba:
+                        attn_tensor_page_size = int(np.prod(kv_cache_shape[1:])) * get_dtype_size(
+                            kv_cache_spec.dtype
+                        )
+                        conv_block_padding_size = raw_k_tensor.numel() - attn_tensor_page_size * 2
+                        raw_kv_tensor = raw_k_tensor[conv_block_padding_size:]
+                        raw_k_tensor = raw_kv_tensor[:attn_tensor_page_size]
+                        raw_v_tensor = raw_kv_tensor[attn_tensor_page_size:]
                 else:
                     kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                         num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
@@ -2753,6 +2786,12 @@ class NPUModelRunner(GPUModelRunner):
                 state_tensors = []
                 target_idx = 0
                 start_idx = 0
+                # NOTE(zxr): to keep all tensors contiguous, we align the ssm and kv
+                # blocks to the same page size and thus have to add an extra padding
+                # block for kv; the overall layout of the hybrid kv_cache on Ascend is:
+                # tensor1: [(kv_padding), conv          , ...]
+                # tensor2: [k           , ssm           , ...]
+                # tensor3: [v           , (mamba_padding), ...]
                 for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
                     # normally, there is conv state and ssm state in this loop. And there is only
                     # a conv state in some special models.
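As an aside, the per-`(shape, dtype)` loop that the note above annotates can be illustrated with a small sketch that carves typed, contiguous state tensors out of one shared byte buffer. The shapes, dtypes, and sizes below are assumptions for illustration only, not the patch's actual values:

```python
import math
import torch

num_blocks = 4
conv_shape, conv_dtype = (8, 128), torch.float16    # assumed conv state
ssm_shape, ssm_dtype = (8, 64, 64), torch.float32   # assumed ssm state

conv_bytes = num_blocks * math.prod(conv_shape) * conv_dtype.itemsize
ssm_bytes = num_blocks * math.prod(ssm_shape) * ssm_dtype.itemsize
flat = torch.zeros(conv_bytes + ssm_bytes, dtype=torch.int8)

start_idx = 0
state_tensors = []
for shape, dtype in ((conv_shape, conv_dtype), (ssm_shape, ssm_dtype)):
    nbytes = num_blocks * math.prod(shape) * dtype.itemsize
    # reinterpret the byte slice as the state dtype, then reshape per block
    view = flat[start_idx : start_idx + nbytes].view(dtype).view(num_blocks, *shape)
    assert view.is_contiguous()
    state_tensors.append(view)
    start_idx += nbytes
```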
@@ -2960,6 +2999,7 @@ class NPUModelRunner(GPUModelRunner):
         # NOTE: Must process Attention/MLAAttention before MambaBase to maintain
         # ordering expected by graph parameter update logic in attention backends.
         mamba_layers: dict[str, MambaBase] = {}
+        attn_layer_names = set()
         for layer_name, attn_module in attn_layers.items():
             if isinstance(attn_module, Attention):
                 if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
@@ -2975,6 +3015,7 @@ class NPUModelRunner(GPUModelRunner):
 
                 if spec := attn_module.get_kv_cache_spec(self.vllm_config):
                     kv_cache_spec[layer_name] = spec
+                    attn_layer_names.add(layer_name)
 
             elif isinstance(attn_module, MLAAttention):
                 if self.use_sparse:
@@ -2997,9 +3038,15 @@ class NPUModelRunner(GPUModelRunner):
         if len(mamba_layers) > 0:
             if self.vllm_config.cache_config.enable_prefix_caching:
                 raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
+            mamba_page_size_padded = 0
             for layer_name, mamba_module in mamba_layers.items():
                 if spec := mamba_module.get_kv_cache_spec(self.vllm_config):
                     kv_cache_spec[layer_name] = spec
+                    mamba_page_size_padded = spec.page_size_bytes
+            # align attn_page_size to mamba_page_size_padded
+            for layer_name in attn_layer_names:
+                if kv_cache_spec[layer_name].page_size_bytes < mamba_page_size_padded:
+                    object.__setattr__(kv_cache_spec[layer_name], "page_size_padded", mamba_page_size_padded)
 
         return kv_cache_spec
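A closing note on the `object.__setattr__` call in the last hunk: that pattern is normally needed because the target is a frozen dataclass, where plain attribute assignment raises `FrozenInstanceError`. A self-contained sketch of the same pattern on a hypothetical `Spec` class (not vLLM's actual spec type):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Spec:  # hypothetical stand-in for a KV cache spec
    page_size_bytes: int
    page_size_padded: int | None = None

spec = Spec(page_size_bytes=4096)
# spec.page_size_padded = 8192  # would raise FrozenInstanceError
object.__setattr__(spec, "page_size_padded", 8192)  # bypasses the frozen check
assert spec.page_size_padded == 8192
```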