diff --git a/vllm_ascend/patch/platform/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_mamba_config.py
index beb7cec7..d0a0f81e 100644
--- a/vllm_ascend/patch/platform/patch_mamba_config.py
+++ b/vllm_ascend/patch/platform/patch_mamba_config.py
@@ -1,11 +1,12 @@
 # mypy: ignore-errors
+import math
+
 import vllm.model_executor.models.config
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.config import MambaModelConfig
 from vllm.utils.math_utils import cdiv
-from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
-from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size


 @classmethod
@@ -33,33 +34,32 @@ def verify_and_update_config(cls, vllm_config) -> None:
     else:
         kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

-    # get attention page size (for 1 token)
-    attn_page_size_1_token = FullAttentionSpec(
-        block_size=1,
-        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
-        head_size=model_config.get_head_size(),
-        dtype=kv_cache_dtype,
-    ).page_size_bytes
+    kernel_block_size = 128
+    # get the attention K cache page size for a single token
+    attn_num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+    attn_head_size = model_config.get_head_size()
+    attn_single_token_k_page_size = attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype)

     model_cls, _ = ModelRegistry.resolve_model_cls(
         model_config.architecture,
         model_config=model_config,
    )

-    # get mamba page size
-    mamba_page_size = MambaSpec(
-        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
-        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
-        block_size=model_config.max_model_len,
-    ).page_size_bytes
+    # get the mamba state page sizes; the larger one is the SSM state
+    # and the smaller one is the conv state
+    mamba_shapes = model_cls.get_mamba_state_shape_from_config(vllm_config)
+    mamba_dtypes = model_cls.get_mamba_state_dtype_from_config(vllm_config)
+    mamba_sizes = []
+    for shape, dtype in zip(mamba_shapes, mamba_dtypes):
+        mamba_sizes.append(math.prod(shape) * get_dtype_size(dtype))
+    ssm_block_page_size, conv_block_page_size = max(mamba_sizes), min(mamba_sizes)

-    block_alignment_bytes = 128
-
-    # some attention backends (e.g. FA) only support setting
-    # block size to multiple of 16, so let's suggest a value
-    # that would work (note: FA is currently not compatible
-    # with mamba layers, use FlashInfer instead).
-    attn_block_size = block_alignment_bytes * cdiv(mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
+    # NOTE(zxr): due to Ascend hardware constraints we have to keep all
+    # cache tensors contiguous, so we align the page size of an SSM block
+    # with that of a single attention block.
+    attn_block_size = kernel_block_size * cdiv(ssm_block_page_size, kernel_block_size * attn_single_token_k_page_size)
+    assert attn_single_token_k_page_size * attn_block_size == ssm_block_page_size, (
+        "Cannot align ssm_page_size and attn_page_size."
+    )
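+    # Worked example (made-up sizes, for illustration only): with a
+    # 4096-byte single-token K page and a 2 MiB SSM page,
+    # cdiv(2097152, 128 * 4096) == 4, so attn_block_size == 128 * 4 == 512
+    # tokens, and 4096 * 512 == 2097152 satisfies the assertion: one
+    # attention K page then lines up exactly with one SSM page.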

     # override attention block size if either (a) the
     # user has not set it or (b) the user has set it
@@ -72,24 +72,25 @@ def verify_and_update_config(cls, vllm_config) -> None:
     )

     # compute new attention page size
-    attn_page_size = cache_config.block_size * attn_page_size_1_token
+    # one page holds both K and V for a whole block, hence the factor of 2
+    attn_page_size = cache_config.block_size * 2 * attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype)

-    assert attn_page_size >= mamba_page_size
-
-    if attn_page_size == mamba_page_size:
-        # don't need to pad mamba page size
-        return
-
-    # pad mamba page size to exactly match attention
-    if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size:
-        cache_config.mamba_page_size_padded = attn_page_size
-        mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
+    # pad the mamba page size so it also covers the conv blocks
+    if (
+        cache_config.mamba_page_size_padded is None
+        or cache_config.mamba_page_size_padded != attn_page_size + conv_block_page_size
+    ):
+        cache_config.mamba_page_size_padded = attn_page_size + conv_block_page_size
+        mamba_padding_pct = 100 * conv_block_page_size / cache_config.mamba_page_size_padded
-        logger.info(
-            "Padding mamba page size by %.2f%% to ensure "
-            "that mamba page size and attention page size are "
-            "exactly equal.",
-            mamba_padding_pct,
-        )
+        logger.info(
+            "Padding mamba page size by %.2f%% so that one mamba page "
+            "holds the conv state in addition to the aligned "
+            "attention pages.",
+            mamba_padding_pct,
+        )
+    if cache_config.enable_prefix_caching and cache_config.mamba_cache_mode == "align":
+        cache_config.mamba_block_size = cache_config.block_size
+    else:
+        cache_config.mamba_block_size = model_config.max_model_len


 vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 47c7c5a6..86cd598c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -45,6 +45,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import LazyLoader
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.mem_utils import DeviceMemoryProfiler
+from vllm.utils.torch_utils import get_dtype_size
 from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -2546,12 +2547,28 @@ class NPUModelRunner(GPUModelRunner):
         kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None] = {}
         # prefill disaggregation need the addr of cache tensor be aligned with 2M
         alignment = 2 * 1024 * 1024
+        layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
+        for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
+            for layer_name in group_kv_cache_spec.layer_names:
+                layer_kv_cache_spec[layer_name] = group_kv_cache_spec.kv_cache_spec
+        # If any tensor is shared by both linear-attention (mamba) and attention
+        # layers, the same tensor layout must be kept for every layer, even for
+        # layers that have only linear or only attention layers, e.g. the MTP layer.
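+        # For illustration (hypothetical layer names): if a kv_cache_tensor has
+        # shared_by == ["model.layers.0.self_attn.attn", "model.layers.1.linear_attn"],
+        # both use_attn and use_mamba become True below, and the runner switches
+        # every tensor to the hybrid attn+mamba layout.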
+        self.hybrid_with_attn_and_mamba = False
         for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-            # TODO: REFACTOR ME to sharing hybrid cache
+            use_mamba, use_attn = False, False
+            for layer_name in kv_cache_tensor.shared_by:
+                if isinstance(layer_kv_cache_spec[layer_name], MambaSpec):
+                    use_mamba = True
+                if isinstance(layer_kv_cache_spec[layer_name], AttentionSpec):
+                    use_attn = True
+            self.hybrid_with_attn_and_mamba = self.hybrid_with_attn_and_mamba or (use_mamba and use_attn)
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
-                if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors:
-                    # for mamba linear attention
+                if (
+                    "linear_attn" in layer_name or self.hybrid_with_attn_and_mamba
+                ) and layer_name not in kv_cache_raw_tensors:
+                    # for mamba linear attention or an attn-mamba hybrid tensor
                     if self.vllm_config.kv_transfer_config is None:
                         tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=self.device)
                     else:
@@ -2560,10 +2577,9 @@ class NPUModelRunner(GPUModelRunner):
                         tensor = self._align_memory(tensor, alignment)[: kv_cache_tensor.size]

                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        # shared the kvcache between the self_attn specs in the same group
-                        if "linear_attn" in layer_name_inner:
-                            kv_cache_raw_tensors[layer_name_inner] = tensor
-                elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors:
+                        # share the kv cache among all layers in shared_by
+                        kv_cache_raw_tensors[layer_name_inner] = tensor
+                elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors and not use_mamba:
                     # NOTE: We need to init k cache tensor (nope cache tensor in mla) and
                     # v cache tensor (rope cache tensor in mla) separately to support prefill disaggregation,
                     # as it only support the 0-dim of kv_cache is `num_blocks`.
@@ -2616,7 +2632,7 @@ class NPUModelRunner(GPUModelRunner):
                         dsa_k_cache_tensor = self._align_memory(dsa_k_cache_tensor, alignment)[:dsa_k_cache_size]

                 for layer_name_inner in kv_cache_tensor.shared_by:
-                    # shared the kvcache between the self_attn specs in the same group
+                    # share the attn kv cache among all shared layers
                     if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
                         kv_cache_raw_tensors[layer_name_inner] = (
                             (k_tensor, v_tensor)
@@ -2651,6 +2667,10 @@ class NPUModelRunner(GPUModelRunner):
         corresponding memory buffer for KV cache.
         """
         kv_caches: dict[str, torch.Tensor] = {}
+        layer_kv_cache_spec = {}
+        for group in kv_cache_config.kv_cache_groups:
+            for layer_name in group.layer_names:
+                layer_kv_cache_spec[layer_name] = group.kv_cache_spec
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
@@ -2668,6 +2688,11 @@ class NPUModelRunner(GPUModelRunner):
                 ]
                 assert raw_dsa_k_tensor is not None
                 sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
+            elif self.use_hybrid_blocks and self.hybrid_with_attn_and_mamba:
+                # Keep the same kv cache format even when a tensor has no shared
+                # mamba layer, e.g. the full-attention MTP layer of qwen3.5.
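+                # Both names alias the same shared buffer here on purpose; the
+                # real K/V split (after the conv padding) is carved out below
+                # once kv_cache_shape is known.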
+                raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[layer_name], kv_cache_raw_tensors[layer_name]
+                sum_page_size_bytes = raw_k_tensor.numel()
             else:
                 raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[  # type: ignore
                     layer_name
@@ -2697,6 +2722,14 @@ class NPUModelRunner(GPUModelRunner):
                         kv_cache_spec.num_kv_heads,
                         kv_cache_spec.head_size,
                     )
+                    if self.hybrid_with_attn_and_mamba:
+                        attn_tensor_page_size = int(np.prod(kv_cache_shape[1:])) * get_dtype_size(
+                            kv_cache_spec.dtype
+                        )
+                        conv_block_padding_size = raw_k_tensor.numel() - attn_tensor_page_size * 2
+                        raw_kv_tensor = raw_k_tensor[conv_block_padding_size:]
+                        raw_k_tensor = raw_kv_tensor[:attn_tensor_page_size]
+                        raw_v_tensor = raw_kv_tensor[attn_tensor_page_size:]
                 else:
                     kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                         num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
                     )
@@ -2753,6 +2786,12 @@ class NPUModelRunner(GPUModelRunner):
                 state_tensors = []
                 target_idx = 0
                 start_idx = 0
+                # NOTE(zxr): to keep all tensors contiguous, the SSM and KV blocks
+                # are aligned to the same page size, which requires an extra padding
+                # block for KV. The overall layout of the hybrid kv_cache on Ascend is:
+                #   tensor1: [(kv_padding), conv,            ...]
+                #   tensor2: [k,            ssm,             ...]
+                #   tensor3: [v,            (mamba_padding), ...]
                 for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
                     # normally, there is conv state and ssm state in this loop. And there is only
                     # a conv state in some special models.
@@ -2960,6 +2999,7 @@ class NPUModelRunner(GPUModelRunner):
         # NOTE: Must process Attention/MLAAttention before MambaBase to maintain
         # ordering expected by graph parameter update logic in attention backends.
         mamba_layers: dict[str, MambaBase] = {}
+        attn_layer_names = set()
         for layer_name, attn_module in attn_layers.items():
             if isinstance(attn_module, Attention):
                 if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
@@ -2975,6 +3015,7 @@ class NPUModelRunner(GPUModelRunner):
                 if spec := attn_module.get_kv_cache_spec(self.vllm_config):
                     kv_cache_spec[layer_name] = spec
+                    attn_layer_names.add(layer_name)

             elif isinstance(attn_module, MLAAttention):
                 if self.use_sparse:
@@ -2997,9 +3038,15 @@ class NPUModelRunner(GPUModelRunner):
         if len(mamba_layers) > 0:
             if self.vllm_config.cache_config.enable_prefix_caching:
                 raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
+            mamba_page_size_padded = 0
             for layer_name, mamba_module in mamba_layers.items():
                 if spec := mamba_module.get_kv_cache_spec(self.vllm_config):
                     kv_cache_spec[layer_name] = spec
+                    mamba_page_size_padded = spec.page_size_bytes
+            # pad every attention page size up to mamba_page_size_padded; the
+            # specs are immutable, hence object.__setattr__
+            for layer_name in attn_layer_names:
+                if kv_cache_spec[layer_name].page_size_bytes < mamba_page_size_padded:
+                    object.__setattr__(kv_cache_spec[layer_name], "page_size_padded", mamba_page_size_padded)

         return kv_cache_spec
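Reviewer note: below is a minimal, self-contained sketch of the page-size
alignment this patch performs, runnable on its own. All sizes and the helper
name align_hybrid_blocks are made up for illustration; only cdiv mirrors
vllm.utils.math_utils.cdiv, and the formulas follow verify_and_update_config
above.

def cdiv(a: int, b: int) -> int:
    # ceiling division, as in vllm.utils.math_utils.cdiv
    return -(-a // b)

def align_hybrid_blocks(
    attn_single_token_k_page_size: int,  # bytes of K cache per token
    mamba_state_sizes: list[int],        # bytes of each mamba state tensor
    kernel_block_size: int = 128,        # Ascend kernel block granularity
) -> tuple[int, int]:
    """Return (attn_block_size, mamba_page_size_padded)."""
    # the larger mamba state is the SSM state, the smaller the conv state
    ssm_block_page_size = max(mamba_state_sizes)
    conv_block_page_size = min(mamba_state_sizes)
    # pick an attention block size (in tokens, a multiple of kernel_block_size)
    # whose K page exactly matches one SSM page
    attn_block_size = kernel_block_size * cdiv(
        ssm_block_page_size, kernel_block_size * attn_single_token_k_page_size
    )
    assert attn_single_token_k_page_size * attn_block_size == ssm_block_page_size, (
        "Cannot align ssm_page_size and attn_page_size."
    )
    # K and V pages together, plus room for the conv state
    attn_page_size = 2 * attn_single_token_k_page_size * attn_block_size
    return attn_block_size, attn_page_size + conv_block_page_size

# hypothetical sizes: 4 KiB per-token K page, 2 MiB SSM state, 64 KiB conv state
print(align_hybrid_blocks(4096, [2 * 1024 * 1024, 64 * 1024]))
# -> (512, 4259840): 512-token attention blocks, mamba pages padded to
#    4 MiB (K + V) plus 64 KiB (conv), matching the layout comments above.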