From e8f7b2e3f19c8ec6a1ac0e323e250724ca2539d1 Mon Sep 17 00:00:00 2001
From: pu-zhe
Date: Thu, 19 Mar 2026 09:16:22 +0800
Subject: [PATCH] [Refactor] [310p] Support Mamba Cache and support
 attn_head_size larger than 128 (#7372)

### What this PR does / why we need it?

1. Mamba Cache Support on 310P: Implemented the logic to correctly
   initialize and allocate the KV cache for Mamba models on the 310P
   platform, including handling of state tensors and page-size alignment.
2. Larger Attention Head Size Support: Modified the attention backend to
   support `attn_head_size` larger than 128 by dynamically selecting
   appropriate kernel block sizes based on hardware limitations
   (e.g., `block_size * head_size <= 16384`).
3. Refactored KV Cache Allocation: Consolidated and improved the KV cache
   allocation mechanism, moving from separate size-calculation and
   allocation steps to a unified `_allocate_kv_cache_tensors` method that
   handles both Attention- and Mamba-specific cache structures.
4. Dynamic Mamba Config Patching: Introduced conditional loading of Mamba
   configuration patches, using `patch_mamba_config_310` on the 310P
   platform to apply platform-specific optimizations and validations.
5. Reserved Memory Headroom: Reserve a reasonable amount of memory when
   allocating the KV cache, to avoid OOM issues with the default
   `gpu_memory_utilization`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Qwen3.5 E2E test

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: pu-zhe
---
 vllm_ascend/_310p/attention/attention_v1.py   |   4 +
 vllm_ascend/_310p/model_runner_310p.py        | 243 +++++++++++-------
 vllm_ascend/_310p/worker_310p.py              |  53 ++++
 vllm_ascend/patch/platform/__init__.py        |   7 +-
 .../patch/platform/patch_mamba_config_310.py  | 104 ++++++++
 5 files changed, 314 insertions(+), 97 deletions(-)
 create mode 100644 vllm_ascend/patch/platform/patch_mamba_config_310.py

diff --git a/vllm_ascend/_310p/attention/attention_v1.py b/vllm_ascend/_310p/attention/attention_v1.py
index 3ba75604..919462fd 100644
--- a/vllm_ascend/_310p/attention/attention_v1.py
+++ b/vllm_ascend/_310p/attention/attention_v1.py
@@ -78,6 +78,10 @@ class AscendAttentionBackend310(AscendAttentionBackend):
         """
         return AscendAttentionMetadataBuilder310

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int]:
+        return [128, 64]
+

class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
     """
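The backend now advertises its kernel block sizes so callers can honor the 310P page-attention constraint `block_size * head_size <= 128 * 128` (= 16384). A minimal sketch of the selection logic, with `pick_kernel_block_size` as a hypothetical helper name, not part of the patch:

```python
# Illustrative sketch: pick the first supported kernel block size that fits
# the 310P page-attention limit block_size * head_size <= 128 * 128.
_ATTENTION_BLOCK_SIZE_LIMIT = 128 * 128

def pick_kernel_block_size(head_size: int, supported=(128, 64)) -> int | None:
    for block_size in supported:
        if block_size * head_size <= _ATTENTION_BLOCK_SIZE_LIMIT:
            return block_size
    return None  # the real code falls back to cache_config.block_size

assert pick_kernel_block_size(128) == 128  # 128 * 128 == 16384, fits
assert pick_kernel_block_size(192) == 64   # 128 * 192 > 16384, drop to 64
assert pick_kernel_block_size(256) == 64   # 64 * 256 == 16384, still fits
```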
diff --git a/vllm_ascend/_310p/model_runner_310p.py b/vllm_ascend/_310p/model_runner_310p.py
index 19e3acde..4a071f12 100644
--- a/vllm_ascend/_310p/model_runner_310p.py
+++ b/vllm_ascend/_310p/model_runner_310p.py
@@ -17,6 +17,7 @@

 from __future__ import annotations

+import math
 from contextlib import contextmanager, nullcontext

 import numpy as np
@@ -24,14 +25,24 @@ import torch
 import torch_npu
 from vllm.config import CUDAGraphMode
 from vllm.logger import logger
+from vllm.utils.torch_utils import get_dtype_size
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
+from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
+    EncoderOnlyAttentionSpec,
+    KVCacheConfig,
+    KVCacheSpec,
+    MambaSpec,
+    UniformTypeKVCacheSpecs,
+)

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+from vllm_ascend.worker.npu_input_batch import NPUInputBatch

 _NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN = 1
+_ATTENTION_BLOCK_SIZE_LIMIT = 128 * 128


 class NPUModelRunner310(NPUModelRunner):
@@ -184,6 +195,7 @@ class NPUModelRunner310(NPUModelRunner):

     def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
+        Override the base class method.
         Initialize the memory buffer for KV cache.

         Args:
@@ -199,10 +211,8 @@ class NPUModelRunner310(NPUModelRunner):
             raise ValueError("Deepseek Sparse Attention is not supported for 310P.")
         if self.model_config.use_mla:
             raise ValueError("MLAAttention is not supported for 310P.")
-        # Initialize the memory size for KV cache
-        kv_cache_size = self._calculate_kv_cache_tensors_size(kv_cache_config)
-        # Allocate and reshape KV cache Tensors
-        kv_caches = self._allocate_kv_cache_and_reshape_tensors(kv_cache_config, kv_cache_size)
+        # Initialize the memory buffer for KV cache
+        kv_caches = self._allocate_kv_cache_tensors(kv_cache_config)
         # Set up cross-layer KV cache sharing
         for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
             logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
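For orientation, the allocation method introduced below derives each buffer's block count from the spec's per-block page size. A worked example with purely illustrative numbers, not taken from the patch:

```python
# Illustrative arithmetic: an AttentionSpec page holds K and V for one block.
num_kv_heads, head_size, block_size = 8, 128, 128
dtype_size = 2  # e.g. bfloat16
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size  # K + V
assert page_size_bytes == 524_288

kv_cache_tensor_size = 1 << 30  # a hypothetical 1 GiB buffer from the KV cache config
assert kv_cache_tensor_size % page_size_bytes == 0  # the runner asserts this too
num_blocks = kv_cache_tensor_size // page_size_bytes
assert num_blocks == 2048
```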
""" # init kv cache tensors - kv_cache_sizes: dict[str, int] = {} + kv_cache: dict[str, list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]] = {} + # get kv cache spec for each layer + layer_kv_cache_spec: dict[str, KVCacheSpec] = {} + for group_kv_cache_spec in kv_cache_config.kv_cache_groups: + for layer_name in group_kv_cache_spec.layer_names: + layer_kv_cache_spec[layer_name] = group_kv_cache_spec.kv_cache_spec + # Allocate kv cache buffers according to the kv_cache_config and kv_cache_spec for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - # TODO: REFACTOR ME to sharing hybrid cache for idx in range(len(kv_cache_tensor.shared_by)): layer_name = kv_cache_tensor.shared_by[idx] - if "linear_attn" in layer_name and layer_name not in kv_cache_sizes: - # for mamba linear attention - kv_cache_size = kv_cache_tensor.size - for layer_name_inner in kv_cache_tensor.shared_by: - # shared the kvcache between the self_attn specs in the same group - if "linear_attn" in layer_name_inner: - kv_cache_sizes[layer_name_inner] = kv_cache_size - elif "attn" in layer_name and layer_name not in kv_cache_sizes: - kv_tensor_split_factor = 2 - kv_tensor_size = int(kv_cache_tensor.size // kv_tensor_split_factor) - for layer_name_inner in kv_cache_tensor.shared_by: - # shared the kvcache between the self_attn specs in the same group - if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner: - kv_cache_sizes[layer_name_inner] = kv_tensor_size - - layer_names = set() - for group in kv_cache_config.kv_cache_groups: - for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: continue - layer_names.add(layer_name) - assert layer_names == set(kv_cache_sizes.keys()), "Some layers are not correctly initialized" - - return kv_cache_sizes - - def _allocate_kv_cache_and_reshape_tensors( - self, - kv_cache_config: KVCacheConfig, - kv_cache_sizes: dict[str, int], - ) -> dict[str, torch.Tensor]: - """ - Allocate the KV cache tensors to the desired shape and dtype. - - Args: - kv_cache_config: The KV cache config - kv_cache_sizes: The KV cache size of each layer - Returns: - dict[str, torch.Tensor]: A map between layer names to their - corresponding memory buffer for KV cache. 
- """ - kv_caches: dict[str, torch.Tensor] = {} - for group in self._kv_cache_spec_attn_group_iterator(): - kv_cache_spec = group.kv_cache_spec - attn_backend = group.backend - for layer_name in group.layer_names: - if layer_name in self.runner_only_attn_layers: - continue - if isinstance(kv_cache_spec, AttentionSpec): - kv_tensor_size = kv_cache_sizes[layer_name] - assert kv_tensor_size is not None - sum_page_size_bytes = kv_tensor_size * 2 - assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0 - num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes + if "linear_attn" in layer_name and layer_name not in kv_cache: + cache_spec = layer_kv_cache_spec[layer_name] + assert isinstance(cache_spec, MambaSpec) + assert kv_cache_tensor.size % cache_spec.page_size_bytes == 0 + num_blocks = kv_cache_tensor.size // cache_spec.page_size_bytes assert num_blocks >= kv_cache_config.num_blocks - - if hasattr(attn_backend, "get_supported_kernel_block_sizes") and self.use_hybrid_blocks: - block_size = attn_backend.get_supported_kernel_block_sizes()[0] - + raw_tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=self.device) + state_tensors = [] + target_idx = 0 + start_idx = 0 + for shape, dtype in zip(cache_spec.shapes, cache_spec.dtypes): + target_shape = (num_blocks, *shape) + target_idx += math.prod(target_shape) * get_dtype_size(dtype) + tensor = raw_tensor[start_idx:target_idx].view(dtype).view(target_shape) + start_idx = target_idx + state_tensors.append(tensor) + for layer_name_inner in kv_cache_tensor.shared_by: + if "linear_attn" in layer_name_inner: + kv_cache[layer_name_inner] = state_tensors + elif "attn" in layer_name and layer_name not in kv_cache: + kv_cache_spec = layer_kv_cache_spec[layer_name] + assert isinstance(kv_cache_spec, AttentionSpec) + assert kv_cache_tensor.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = kv_cache_tensor.size // kv_cache_spec.page_size_bytes + assert num_blocks >= kv_cache_config.num_blocks + # Page attention operation on 310P limits block_size * head_size <= 128 * 128 + supported_sizes = [ + support_size + for support_size in self.attn_backend.get_supported_kernel_block_sizes() + if support_size * kv_cache_spec.head_size <= _ATTENTION_BLOCK_SIZE_LIMIT + ] + if supported_sizes: + block_size = supported_sizes[0] block_size_chunk = kv_cache_spec.block_size // block_size - kv_cache_shape = attn_backend.get_kv_cache_shape( + kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks * block_size_chunk, block_size, kv_cache_spec.num_kv_heads, @@ -299,43 +291,27 @@ class NPUModelRunner310(NPUModelRunner): kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size ) - dtype = kv_cache_spec.dtype k_shape = kv_cache_shape[1:] v_shape = k_shape + dtype = kv_cache_spec.dtype k_cache = torch_npu.empty_with_format( size=k_shape, dtype=dtype, device=self.device, acl_format=self._acl_format ) v_cache = torch_npu.empty_with_format( size=v_shape, dtype=dtype, device=self.device, acl_format=self._acl_format ) - kv_caches[layer_name] = (k_cache, v_cache) - elif isinstance(kv_cache_spec, MambaSpec): - tensor_size = kv_cache_sizes[layer_name] - dtype = kv_cache_spec.dtype - tensor_element_size = torch.tensor([], dtype=dtype).element_size() - raw_tensor = torch.zeros(tensor_size // tensor_element_size, dtype=dtype, device=self.device) - assert tensor_size is not None - assert tensor_size % kv_cache_spec.page_size_bytes == 0 - num_blocks = 
+
+    def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
+        """
+        Re-initialize the input batch if the block sizes are different from
+        `[self.cache_config.block_size]`. This usually happens when there
+        are multiple KV cache groups.
+
+        Args:
+            kv_cache_config: The KV cache configuration.
+        """
+        block_sizes = [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
+            if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
+        ]
+
+        # Generate kernel_block_sizes entries that match each block_size.
+        # For attention backends that support virtual block splitting,
+        # use the supported block sizes from the backend.
+        # For other backends (like Mamba), use [0] (no splitting).
+        self.kernel_block_sizes = []
+        for kv_cache_group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            kv_cache_spec = kv_cache_group.kv_cache_spec
+            if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
+                kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
+            if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
+                continue
+            elif isinstance(kv_cache_spec, AttentionSpec):
+                try:
+                    attn_groups = self.attn_groups[kv_cache_group_id]
+                    backend = attn_groups[0].backend
+                    # Page attention operation on 310P limits block_size * head_size <= 128 * 128
+                    supported_sizes = [
+                        support_size
+                        for support_size in backend.get_supported_kernel_block_sizes()
+                        if support_size * kv_cache_spec.head_size <= _ATTENTION_BLOCK_SIZE_LIMIT
+                    ]
+                    kernel_block_size_list = supported_sizes if supported_sizes else [self.cache_config.block_size]
+                except IndexError:
+                    kernel_block_size_list = [self.cache_config.block_size]
+                self.kernel_block_sizes.append(kernel_block_size_list)
+            else:
+                self.kernel_block_sizes.append([0])
+
+        if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]:
+            if vllm_version_is("0.17.0"):
+                assert self.cache_config.cpu_offload_gb == 0, (
+                    "Cannot re-initialize the input batch when CPU weight "
+                    "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                    "for more details."
+                )
+            else:
+                assert self.offload_config.uva.cpu_offload_gb == 0, (
+                    "Cannot re-initialize the input batch when CPU weight "
+                    "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                    "for more details."
+                )
+            self.input_batch = NPUInputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=max(self.model_config.max_model_len, self.max_encoder_len),
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_sizes=block_sizes,
+                is_spec_decode=bool(self.vllm_config.speculative_config),
+                logitsprocs=self.input_batch.logitsprocs,
+                is_pooling_model=self.is_pooling_model,
+                num_speculative_tokens=(
+                    self.vllm_config.speculative_config.num_speculative_tokens
+                    if self.vllm_config.speculative_config
+                    else 0
+                ),
+                kernel_block_sizes=self.kernel_block_sizes,
+            )
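To make the `kernel_block_sizes` structure concrete: each non-encoder KV cache group contributes one list, with `[0]` meaning no virtual block splitting. A sketch with assumed values for a hybrid model (one attention group, one Mamba group):

```python
# Illustrative only: made-up specs, not taken from any real model config.
cache_config_block_size = 256
attention_supported = [128, 64]  # from get_supported_kernel_block_sizes()
head_size = 128
limit = 128 * 128

attn_entry = [s for s in attention_supported if s * head_size <= limit] or [cache_config_block_size]
mamba_entry = [0]  # Mamba layers do not use virtual block splitting

kernel_block_sizes = [attn_entry, mamba_entry]
assert kernel_block_sizes == [[128, 64], [0]]
```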
+ """ + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) + ] + + # Generate kernel_block_sizes that matches each block_size + # For attention backends that support virtual block splitting, + # use the supported block sizes from the backend + # For other backends (like Mamba), use [0] (no splitting) + self.kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group.kv_cache_spec + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): + kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values())) + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): + continue + elif isinstance(kv_cache_spec, AttentionSpec): + try: + attn_groups = self.attn_groups[kv_cache_group_id] + backend = attn_groups[0].backend + # Page attention operation on 310P limits block_size * head_size <= 128 * 128 + supported_sizes = [ + support_size + for support_size in backend.get_supported_kernel_block_sizes() + if support_size * kv_cache_spec.head_size <= _ATTENTION_BLOCK_SIZE_LIMIT + ] + kernel_block_size_list = supported_sizes if supported_sizes else [self.cache_config.block_size] + except IndexError: + kernel_block_size_list = [self.cache_config.block_size] + self.kernel_block_sizes.append(kernel_block_size_list) + else: + self.kernel_block_sizes.append([0]) + + if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]: + if vllm_version_is("0.17.0"): + assert self.cache_config.cpu_offload_gb == 0, ( + "Cannot re-initialize the input batch when CPU weight " + "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 + "for more details." + ) + else: + assert self.offload_config.uva.cpu_offload_gb == 0, ( + "Cannot re-initialize the input batch when CPU weight " + "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 + "for more details." + ) + self.input_batch = NPUInputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=max(self.model_config.max_model_len, self.max_encoder_len), + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=block_sizes, + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config + else 0 + ), + kernel_block_sizes=self.kernel_block_sizes, + ) diff --git a/vllm_ascend/_310p/worker_310p.py b/vllm_ascend/_310p/worker_310p.py index 87fe3a81..5705a8e3 100644 --- a/vllm_ascend/_310p/worker_310p.py +++ b/vllm_ascend/_310p/worker_310p.py @@ -14,8 +14,11 @@ # limitations under the License. # This file is a part of the vllm-ascend project. 
+
     def _warm_up_atb(self):
         # 310p device do not support torch_npu._npu_matmul_add_fp32 atb ops
         logger.info("Skip warm-up atb ops for 310P device.")
diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py
index 6bda63f0..1e4e5b49 100644
--- a/vllm_ascend/patch/platform/__init__.py
+++ b/vllm_ascend/patch/platform/__init__.py
@@ -19,7 +19,12 @@ import os
 import vllm_ascend.patch.platform.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops  # noqa
 import vllm_ascend.patch.platform.patch_kv_cache_interface  # noqa
-import vllm_ascend.patch.platform.patch_mamba_config  # noqa
+from vllm_ascend.utils import is_310p
+
+if not is_310p():
+    import vllm_ascend.patch.platform.patch_mamba_config  # noqa
+else:
+    import vllm_ascend.patch.platform.patch_mamba_config_310  # noqa
 import vllm_ascend.patch.platform.patch_minimax_m2_config  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 import vllm_ascend.patch.platform.patch_torch_accelerator  # noqa
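Both import branches rely on the same mechanism: importing the patch module executes a module-level monkey-patch. Reduced to its skeleton (the full body follows in the new file below):

```python
import vllm.model_executor.models.config as cfg

@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    ...  # 310P-specific page-size alignment, shown in full in the new file

# Rebind the upstream classmethod in place; this runs at import time.
cfg.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
```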
diff --git a/vllm_ascend/patch/platform/patch_mamba_config_310.py b/vllm_ascend/patch/platform/patch_mamba_config_310.py
new file mode 100644
index 00000000..db9775f3
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_mamba_config_310.py
@@ -0,0 +1,104 @@
+# mypy: ignore-errors
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from math import lcm
+
+import vllm.model_executor.models.config
+from vllm.logger import init_logger
+from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.config import MambaModelConfig
+from vllm.utils.math_utils import cdiv
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
+
+
+@classmethod
+def verify_and_update_config(cls, vllm_config) -> None:
+    """
+    Ensure that the page size of the attention layers is greater than or
+    equal to that of the mamba layers. If not, automatically set the
+    attention block size to ensure that it is. If the attention page size
+    is strictly greater than the mamba page size, we pad the mamba page
+    size to make them equal.
+
+    Args:
+        vllm_config: vLLM Config
+    """
+    logger = init_logger(__name__)
+    # Save the user input before it gets modified by MambaModelConfig
+    mamba_block_size = vllm_config.cache_config.mamba_block_size
+    # Enable FULL_AND_PIECEWISE by default
+    MambaModelConfig.verify_and_update_config(vllm_config)
+    cache_config = vllm_config.cache_config
+    model_config = vllm_config.model_config
+    parallel_config = vllm_config.parallel_config
+
+    if cache_config.cache_dtype == "auto":
+        kv_cache_dtype = model_config.dtype
+    else:
+        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
+    # get attention page size (for 1 token)
+    if model_config.use_mla:
+        raise RuntimeError("MLA is not supported on 310P currently.")
+    kernel_block_alignment_size = 128
+    attn_page_size_1_token = FullAttentionSpec(
+        block_size=1,
+        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+        head_size=model_config.get_head_size(),
+        dtype=kv_cache_dtype,
+    ).page_size_bytes
+
+    model_cls, _ = ModelRegistry.resolve_model_cls(
+        model_config.architecture,
+        model_config=model_config,
+    )
+
+    # get mamba page size
+    mamba_page_size = MambaSpec(
+        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
+        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
+        block_size=-1,
+    ).page_size_bytes
+
+    # The model may be marked as is_hybrid while mamba is skipped via
+    # config; return directly in that case.
+    if mamba_page_size == 0:
+        return
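The non-`"all"` branch below rounds the attention block size up to a multiple of the 128-token kernel alignment, so that one attention page can cover one mamba page. A worked example with invented sizes:

```python
kernel_block_alignment_size = 128
attn_page_size_1_token = 4_096   # hypothetical bytes per token (K+V, all heads)
mamba_page_size = 2_000_000      # hypothetical bytes per mamba state page

cdiv = lambda a, b: -(-a // b)   # ceiling division, as in vllm.utils.math_utils
attn_block_size = kernel_block_alignment_size * cdiv(
    mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
assert attn_block_size == 512    # 512 * 4096 = 2_097_152 >= 2_000_000
```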
+    if cache_config.mamba_cache_mode == "all":
+        base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
+        attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
+        chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
+        attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
+        cache_config.mamba_block_size = attn_block_size
+    else:
+        attn_block_size = kernel_block_alignment_size * cdiv(
+            mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
+        )
+    if cache_config.block_size is None or cache_config.block_size < attn_block_size:
+        cache_config.block_size = attn_block_size
+        logger.info(
+            "Setting attention block size to %d tokens to ensure that attention page size is >= mamba page size.",
+            attn_block_size,
+        )
+    if cache_config.mamba_cache_mode == "align":
+        cache_config.mamba_block_size = cache_config.block_size
+    attn_page_size = cache_config.block_size * attn_page_size_1_token
+    assert attn_page_size >= mamba_page_size
+    if attn_page_size == mamba_page_size:
+        # don't need to pad mamba page size
+        return
+    # pad mamba page size to exactly match the attention page size
+    if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size:
+        cache_config.mamba_page_size_padded = attn_page_size
+        mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
+        logger.info(
+            "Padding mamba page size by %.2f%% to ensure "
+            "that mamba page size and attention page size are "
+            "exactly equal.",
+            mamba_padding_pct,
+        )
+
+
+vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
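Continuing the invented sizes from the previous example, the final padding step works out as follows:

```python
attn_page_size = 2_097_152    # block_size * attn_page_size_1_token (512 * 4096)
mamba_page_size = 2_000_000

assert attn_page_size >= mamba_page_size
mamba_page_size_padded = attn_page_size  # pad so both page sizes match exactly
mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
assert round(mamba_padding_pct, 2) == 4.86
```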