[Refactor] [310p] Support Mamba Cache and support attn_head_size larger than 128 (#7372)

### What this PR does / why we need it?
1. Mamba Cache Support on 310P: Implemented logic to correctly
initialize and allocate KV cache for Mamba models on the 310P platform,
including handling of state tensors and page size alignment.
2. Increased Attention Head Size Support: Modified the attention backend
to support attn_head_size larger than 128 by dynamically selecting an
appropriate kernel block size based on hardware limitations (e.g.,
block_size * head_size <= 16384); see the sketch after this list.
3. Refactored KV Cache Allocation: Consolidated and improved the KV
cache allocation mechanism, moving from separate size calculation and
allocation steps to a unified _allocate_kv_cache_tensors method that
handles both Attention and Mamba specific cache structures.
4. Dynamic Mamba Config Patching: Introduced conditional loading of
Mamba configuration patches, specifically using patch_mamba_config_310
for the 310P platform to ensure platform-specific optimizations and
validations.
5. Reserve a reasonable amount of memory when allocating the KV cache to
avoid OOM issues with the default gpu_memory_utilization.
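
A minimal sketch of the block-size selection described in item 2, assuming a backend that advertises its kernel block sizes the way `get_supported_kernel_block_sizes()` does in this PR; the helper name and the fallback value are illustrative only:

```python
# Sketch only: mirrors the selection logic added to model_runner_310p.py.
_ATTENTION_BLOCK_SIZE_LIMIT = 128 * 128  # 310P paged attention: block_size * head_size <= 16384


def select_kernel_block_size(head_size: int,
                             supported_kernel_block_sizes: tuple[int, ...] = (128, 64),
                             default_block_size: int = 128) -> int:
    """Pick the first supported kernel block size that fits the 310P limit."""
    fitting = [
        size for size in supported_kernel_block_sizes
        if size * head_size <= _ATTENTION_BLOCK_SIZE_LIMIT
    ]
    # Fall back to the configured cache block size when nothing fits.
    return fitting[0] if fitting else default_block_size


print(select_kernel_block_size(128))  # 128: 128 * 128 == 16384 still fits
print(select_kernel_block_size(192))  # 64:  128 * 192 exceeds the limit, 64 * 192 does not
```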
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Qwen3.5 E2E test
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
Author: pu-zhe (committed by GitHub), 2026-03-19 09:16:22 +08:00
parent 8b79d4de52
commit e8f7b2e3f1
5 changed files with 314 additions and 97 deletions


@@ -78,6 +78,10 @@ class AscendAttentionBackend310(AscendAttentionBackend):
"""
return AscendAttentionMetadataBuilder310
@staticmethod
def get_supported_kernel_block_sizes() -> list[int]:
return [128, 64]
class AscendAttentionBackendImpl310(AscendAttentionBackendImpl):
"""


@@ -17,6 +17,7 @@
from __future__ import annotations
import math
from contextlib import contextmanager, nullcontext
import numpy as np
@@ -24,14 +25,24 @@ import torch
import torch_npu
from vllm.config import CUDAGraphMode
from vllm.logger import logger
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
from vllm.v1.kv_cache_interface import (
AttentionSpec,
EncoderOnlyAttentionSpec,
KVCacheConfig,
KVCacheSpec,
MambaSpec,
UniformTypeKVCacheSpecs,
)
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, vllm_version_is
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
_NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN = 1
_ATTENTION_BLOCK_SIZE_LIMIT = 128 * 128
class NPUModelRunner310(NPUModelRunner):
@@ -184,6 +195,7 @@ class NPUModelRunner310(NPUModelRunner):
def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Override the base class method.
Initialize the memory buffer for KV cache.
Args:
@@ -199,10 +211,8 @@ class NPUModelRunner310(NPUModelRunner):
raise ValueError("Deepseek Sparse Attention is not supported for 310P.")
if self.model_config.use_mla:
raise ValueError("MLAAttention is not supported for 310P.")
# Initialize the memory size for KV cache
kv_cache_size = self._calculate_kv_cache_tensors_size(kv_cache_config)
# Allocate and reshape KV cache Tensors
kv_caches = self._allocate_kv_cache_and_reshape_tensors(kv_cache_config, kv_cache_size)
# Initialize the memory buffer for KV cache
kv_caches = self._allocate_kv_cache_tensors(kv_cache_config)
# Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
@@ -213,7 +223,7 @@ class NPUModelRunner310(NPUModelRunner):
bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches)
return kv_caches
def _calculate_kv_cache_tensors_size(self, kv_cache_config: KVCacheConfig) -> dict[str, int]:
def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initializes the KV cache size. The buffer needs to be reshaped to the desired shape before being used by
the models.
@@ -221,75 +231,57 @@ class NPUModelRunner310(NPUModelRunner):
Args:
kv_cache_config: The KV cache config
Returns:
dict[str, int]: A map between layer names to their
corresponding memory buffer size.
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer.
"""
# init kv cache tensors
kv_cache_sizes: dict[str, int] = {}
kv_cache: dict[str, list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]] = {}
# get kv cache spec for each layer
layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
for layer_name in group_kv_cache_spec.layer_names:
layer_kv_cache_spec[layer_name] = group_kv_cache_spec.kv_cache_spec
# Allocate kv cache buffers according to the kv_cache_config and kv_cache_spec
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
# TODO: REFACTOR ME to share the hybrid cache
for idx in range(len(kv_cache_tensor.shared_by)):
layer_name = kv_cache_tensor.shared_by[idx]
if "linear_attn" in layer_name and layer_name not in kv_cache_sizes:
# for mamba linear attention
kv_cache_size = kv_cache_tensor.size
for layer_name_inner in kv_cache_tensor.shared_by:
# share the kv cache size between the linear_attn layers in the same group
if "linear_attn" in layer_name_inner:
kv_cache_sizes[layer_name_inner] = kv_cache_size
elif "attn" in layer_name and layer_name not in kv_cache_sizes:
kv_tensor_split_factor = 2
kv_tensor_size = int(kv_cache_tensor.size // kv_tensor_split_factor)
for layer_name_inner in kv_cache_tensor.shared_by:
# share the kv cache size between the self_attn layers in the same group
if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
kv_cache_sizes[layer_name_inner] = kv_tensor_size
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
layer_names.add(layer_name)
assert layer_names == set(kv_cache_sizes.keys()), "Some layers are not correctly initialized"
return kv_cache_sizes
def _allocate_kv_cache_and_reshape_tensors(
self,
kv_cache_config: KVCacheConfig,
kv_cache_sizes: dict[str, int],
) -> dict[str, torch.Tensor]:
"""
Allocate the KV cache tensors to the desired shape and dtype.
Args:
kv_cache_config: The KV cache config
kv_cache_sizes: The KV cache size of each layer
Returns:
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_caches: dict[str, torch.Tensor] = {}
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
if isinstance(kv_cache_spec, AttentionSpec):
kv_tensor_size = kv_cache_sizes[layer_name]
assert kv_tensor_size is not None
sum_page_size_bytes = kv_tensor_size * 2
assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes
if "linear_attn" in layer_name and layer_name not in kv_cache:
cache_spec = layer_kv_cache_spec[layer_name]
assert isinstance(cache_spec, MambaSpec)
assert kv_cache_tensor.size % cache_spec.page_size_bytes == 0
num_blocks = kv_cache_tensor.size // cache_spec.page_size_bytes
assert num_blocks >= kv_cache_config.num_blocks
if hasattr(attn_backend, "get_supported_kernel_block_sizes") and self.use_hybrid_blocks:
block_size = attn_backend.get_supported_kernel_block_sizes()[0]
raw_tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=self.device)
state_tensors = []
target_idx = 0
start_idx = 0
for shape, dtype in zip(cache_spec.shapes, cache_spec.dtypes):
target_shape = (num_blocks, *shape)
target_idx += math.prod(target_shape) * get_dtype_size(dtype)
tensor = raw_tensor[start_idx:target_idx].view(dtype).view(target_shape)
start_idx = target_idx
state_tensors.append(tensor)
for layer_name_inner in kv_cache_tensor.shared_by:
if "linear_attn" in layer_name_inner:
kv_cache[layer_name_inner] = state_tensors
elif "attn" in layer_name and layer_name not in kv_cache:
kv_cache_spec = layer_kv_cache_spec[layer_name]
assert isinstance(kv_cache_spec, AttentionSpec)
assert kv_cache_tensor.size % kv_cache_spec.page_size_bytes == 0
num_blocks = kv_cache_tensor.size // kv_cache_spec.page_size_bytes
assert num_blocks >= kv_cache_config.num_blocks
# The paged attention operator on 310P requires block_size * head_size <= 128 * 128
supported_sizes = [
support_size
for support_size in self.attn_backend.get_supported_kernel_block_sizes()
if support_size * kv_cache_spec.head_size <= _ATTENTION_BLOCK_SIZE_LIMIT
]
if supported_sizes:
block_size = supported_sizes[0]
block_size_chunk = kv_cache_spec.block_size // block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks * block_size_chunk,
block_size,
kv_cache_spec.num_kv_heads,
@@ -299,43 +291,27 @@ class NPUModelRunner310(NPUModelRunner):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
)
dtype = kv_cache_spec.dtype
k_shape = kv_cache_shape[1:]
v_shape = k_shape
dtype = kv_cache_spec.dtype
k_cache = torch_npu.empty_with_format(
size=k_shape, dtype=dtype, device=self.device, acl_format=self._acl_format
)
v_cache = torch_npu.empty_with_format(
size=v_shape, dtype=dtype, device=self.device, acl_format=self._acl_format
)
kv_caches[layer_name] = (k_cache, v_cache)
elif isinstance(kv_cache_spec, MambaSpec):
tensor_size = kv_cache_sizes[layer_name]
dtype = kv_cache_spec.dtype
tensor_element_size = torch.tensor([], dtype=dtype).element_size()
raw_tensor = torch.zeros(tensor_size // tensor_element_size, dtype=dtype, device=self.device)
assert tensor_size is not None
assert tensor_size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
assert num_blocks >= kv_cache_config.num_blocks
state_tensors = []
target_idx = 0
start_idx = 0
for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
# Normally there are both a conv state and an ssm state in this loop;
# some special models have only a conv state.
target_shape = (num_blocks, *shape)
target_idx += torch.prod(torch.tensor(target_shape)).item()
tensor = raw_tensor[start_idx:target_idx].view(target_shape)
start_idx = target_idx
state_tensors.append(tensor)
kv_caches[layer_name] = state_tensors
else:
raise ValueError("Unknown KV cache spec type.")
return kv_caches
for layer_name_inner in kv_cache_tensor.shared_by:
# share the kv cache between the self_attn layers in the same group
if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
kv_cache[layer_name_inner] = (k_cache, v_cache)
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
layer_names.add(layer_name)
assert layer_names == set(kv_cache.keys()), "Some layers are not correctly initialized"
return kv_cache
# Override this function because of a tensor.copy_(other) accuracy issue.
# TODO: This override will be removed after the tensor.copy_(other) accuracy issue is resolved.
@@ -430,3 +406,78 @@ class NPUModelRunner310(NPUModelRunner):
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
)
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
are multiple KV cache groups.
Args:
kv_cache_config: The KV cache configuration.
"""
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]
# Generate kernel_block_sizes entries that match each block_size.
# For attention backends that support virtual block splitting,
# use the supported block sizes from the backend.
# For other backends (like Mamba), use [0] (no splitting).
self.kernel_block_sizes = []
for kv_cache_group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
continue
elif isinstance(kv_cache_spec, AttentionSpec):
try:
attn_groups = self.attn_groups[kv_cache_group_id]
backend = attn_groups[0].backend
# The paged attention operator on 310P requires block_size * head_size <= 128 * 128
supported_sizes = [
support_size
for support_size in backend.get_supported_kernel_block_sizes()
if support_size * kv_cache_spec.head_size <= _ATTENTION_BLOCK_SIZE_LIMIT
]
kernel_block_size_list = supported_sizes if supported_sizes else [self.cache_config.block_size]
except IndexError:
kernel_block_size_list = [self.cache_config.block_size]
self.kernel_block_sizes.append(kernel_block_size_list)
else:
self.kernel_block_sizes.append([0])
if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]:
if vllm_version_is("0.17.0"):
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details."
)
else:
assert self.offload_config.uva.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details."
)
self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=max(self.model_config.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=(
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config
else 0
),
kernel_block_sizes=self.kernel_block_sizes,
)
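
As a note on the Mamba branch of the new `_allocate_kv_cache_tensors` above: the per-layer buffer is allocated as a flat int8 tensor and then sliced into one state tensor per (shape, dtype) pair, mirroring the loop over `cache_spec.shapes` and `cache_spec.dtypes`. A standalone sketch of that carving pattern, with a made-up helper name and example shapes:

```python
import math

import torch


def carve_state_tensors(raw_tensor: torch.Tensor, num_blocks: int,
                        shapes: list[tuple[int, ...]],
                        dtypes: list[torch.dtype]) -> list[torch.Tensor]:
    """Slice a flat int8 buffer into per-state tensors, one per (shape, dtype)."""
    state_tensors = []
    start_idx = 0
    for shape, dtype in zip(shapes, dtypes):
        target_shape = (num_blocks, *shape)
        nbytes = math.prod(target_shape) * torch.tensor([], dtype=dtype).element_size()
        chunk = raw_tensor[start_idx:start_idx + nbytes]
        state_tensors.append(chunk.view(dtype).view(target_shape))
        start_idx += nbytes
    return state_tensors


# Example: a conv state (fp16) and an ssm state (fp32) carved from one buffer.
raw = torch.zeros(4 * (16 * 2 + 8 * 4), dtype=torch.int8)  # 4 blocks
conv, ssm = carve_state_tensors(raw, num_blocks=4,
                                shapes=[(16,), (8,)],
                                dtypes=[torch.float16, torch.float32])
print(conv.shape, ssm.shape)  # torch.Size([4, 16]) torch.Size([4, 8])
```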


@@ -14,8 +14,11 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
import torch_npu
from vllm.logger import logger
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import memory_profiling
from vllm_ascend._310p.model_runner_310p import NPUModelRunner310
from vllm_ascend.worker.worker import NPUWorker, init_workspace_manager
@@ -47,6 +50,56 @@ class NPUWorker310(NPUWorker):
ShardedStateLoader310.generate_quant_description(self.model_runner.model, path)
@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the free memory that can be used for KV cache in
bytes.
"""
GiB = lambda b: b / GiB_bytes
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(
self.init_snapshot,
weights_memory=int(self.model_runner.model_memory_usage),
) as profile_result:
self.model_runner.profile_run()
free_memory, total_memory = torch.npu.mem_get_info()
torch_memory = torch.npu.memory_reserved()
non_torch_memory_before_empty_cache = total_memory - free_memory - torch_memory
self.non_torch_memory = profile_result.non_torch_increase
self.peak_activation_memory = profile_result.torch_peak_increase
non_torch_memory_cleared_by_empty_cache = non_torch_memory_before_empty_cache - self.non_torch_memory
free_gpu_memory = profile_result.after_profile.free_memory
assert self.init_snapshot.free_memory > free_gpu_memory, (
"Error in memory profiling. "
f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
f"current free memory {GiB(free_gpu_memory)} GiB. "
"This happens when other processes sharing the same container "
"release GPU memory while vLLM is profiling during initialization. "
"To fix this, ensure consistent GPU memory allocation or "
"isolate vLLM in its own container."
)
# Divide the available memory by 2 to reserve more memory for other operators' workspace and other caches.
# This avoids OOM with the default gpu_memory_utilization.
self.available_kv_cache_memory_bytes = (
self.requested_memory - profile_result.non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
) // 2
logger.debug(profile_result)
logger.info_once(
"Available KV cache memory: %.2f GiB",
GiB(self.available_kv_cache_memory_bytes),
scope="local",
)
return int(self.available_kv_cache_memory_bytes)
def _warm_up_atb(self):
# The 310P device does not support the torch_npu._npu_matmul_add_fp32 atb op
logger.info("Skip warm-up atb ops for 310P device.")


@@ -19,7 +19,12 @@ import os
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
import vllm_ascend.patch.platform.patch_kv_cache_interface # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
from vllm_ascend.utils import is_310p
if not is_310p():
import vllm_ascend.patch.platform.patch_mamba_config # noqa
else:
import vllm_ascend.patch.platform.patch_mamba_config_310 # noqa
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa
import vllm_ascend.patch.platform.patch_torch_accelerator # noqa


@@ -0,0 +1,104 @@
# mypy: ignore-errors
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import lcm
import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
"""
Ensure that the page size of the attention layers is greater than or
equal to that of the mamba layers. If not, automatically set the attention
block size to ensure that it is. If the attention page size is
strictly greater than the mamba page size, pad the mamba page size
to make them equal.
Args:
vllm_config: vLLM Config
"""
logger = init_logger(__name__)
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = vllm_config.cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
if cache_config.cache_dtype == "auto":
kv_cache_dtype = model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
if model_config.use_mla:
raise RuntimeError("MLA is not supported on 310P currently.")
kernel_block_alignment_size = 128
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
model_config=model_config,
)
# get mamba page size
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=-1,
).page_size_bytes
# The model may be marked as is_hybrid but mamba may be skipped via
# config; in that case return directly.
if mamba_page_size == 0:
return
if cache_config.mamba_cache_mode == "all":
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens to ensure that attention page size is >= mamba page size.",
attn_block_size,
)
if cache_config.mamba_cache_mode == "align":
cache_config.mamba_block_size = cache_config.block_size
attn_page_size = cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
if attn_page_size == mamba_page_size:
# don't need to pad mamba page size
return
# pad mamba page size to exactly match attention
if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size:
cache_config.mamba_page_size_padded = attn_page_size
mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.",
mamba_padding_pct,
)
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
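
To make the alignment arithmetic in this patch concrete, here is a small worked example with hypothetical page sizes (the real values depend on the model's KV head count, head size, cache dtype, and mamba state shapes):

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, in the spirit of vllm.utils.math_utils.cdiv."""
    return -(a // -b)


kernel_block_alignment_size = 128          # 310P kernel block alignment
attn_page_size_1_token = 8 * 128 * 2 * 2   # num_kv_heads * head_size * dtype_size * (K + V) = 4096 B
mamba_page_size = 2_000_000                # hypothetical conv + ssm state bytes per sequence

# Round the attention block size up so that one attention page
# (block_size * attn_page_size_1_token) can hold one mamba state.
attn_block_size = kernel_block_alignment_size * cdiv(
    mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
attn_page_size = attn_block_size * attn_page_size_1_token
print(attn_block_size, attn_page_size >= mamba_page_size)  # 512 True

# The remaining gap (attn_page_size - mamba_page_size) is then closed by
# setting cache_config.mamba_page_size_padded = attn_page_size, as the patch does.
```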