### What this PR does / why we need it?
1. Mamba Cache Support on 310P: Implemented logic to correctly
initialize and allocate KV cache for Mamba models on the 310P platform,
including handling of state tensors and page size alignment.
2. Increased Attention Head Size Support: Modified the attention backend
to support attn_head_size larger than 128 by dynamically selecting
appropriate kernel block sizes based on hardware limitations (e.g.,
block_size * head_size <= 16384); see the sketch after this list.
3. Refactored KV Cache Allocation: Consolidated and improved the KV
cache allocation mechanism, moving from separate size-calculation and
allocation steps to a unified _allocate_kv_cache_tensors method that
handles both Attention- and Mamba-specific cache structures.
4. Dynamic Mamba Config Patching: Introduced conditional loading of
Mamba configuration patches, specifically using patch_mamba_config_310
for the 310P platform to ensure platform-specific optimizations and
validations.
5. Reserved a reasonable amount of memory for KV cache allocation to avoid
OOM issues with the default gpu_memory_utilization.
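
For item 2, here is a minimal sketch (not the actual backend code from this PR) of how a kernel block size could be chosen under the `block_size * head_size <= 16384` limit quoted above; the function name, the preferred size of 128, and the halving strategy are illustrative assumptions:

```python
# Illustrative sketch only: pick the largest block size, starting from a
# preferred value, such that block_size * head_size stays within the limit.
MAX_BLOCK_ELEMS = 16384  # limit cited in this PR: block_size * head_size <= 16384


def pick_kernel_block_size(head_size: int, preferred: int = 128) -> int:
    block_size = preferred
    # Halve the block size until the kernel constraint is satisfied.
    while block_size > 1 and block_size * head_size > MAX_BLOCK_ELEMS:
        block_size //= 2
    return block_size


# Examples: head_size=128 -> 128 tokens/block; head_size=256 -> 64 tokens/block.
```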
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Qwen3.5 E2E test
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
# mypy: ignore-errors
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from math import lcm

import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    """
    Ensure that page size of attention layers is greater than or
    equal to the mamba layers. If not, automatically set the attention
    block size to ensure that it is. If the attention page size is
    strictly greater than the mamba page size, we pad the mamba page size
    to make them equal.

    Args:
        vllm_config: vLLM Config
    """
    logger = init_logger(__name__)

    # Save the user input before it gets modified by MambaModelConfig
    mamba_block_size = vllm_config.cache_config.mamba_block_size
    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    if model_config.use_mla:
        raise RuntimeError("MLA is not supported on 310P currently.")
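    # Alignment granularity (in tokens) for the attention kernel block size
    # on 310P; the attention block sizes computed below are multiples of it.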
    kernel_block_alignment_size = 128
    attn_page_size_1_token = FullAttentionSpec(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype,
    ).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=-1,
    ).page_size_bytes

    # Model may be marked as is_hybrid
    # but mamba is skipped via config,
    # return directly
    if mamba_page_size == 0:
        return

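    # Choose an attention block size whose page covers at least one mamba
    # state: in "all" mode round up to a whole number of (kernel-aligned)
    # mamba chunks, otherwise round up to the kernel alignment only.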
    if cache_config.mamba_cache_mode == "all":
        base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
        attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
        chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
        attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
        cache_config.mamba_block_size = attn_block_size
    else:
        attn_block_size = kernel_block_alignment_size * cdiv(
            mamba_page_size,
            kernel_block_alignment_size * attn_page_size_1_token,
        )

    if (cache_config.block_size is None
            or cache_config.block_size < attn_block_size):
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens to ensure that "
            "attention page size is >= mamba page size.",
            attn_block_size,
        )

    if cache_config.mamba_cache_mode == "align":
        cache_config.mamba_block_size = cache_config.block_size

    attn_page_size = cache_config.block_size * attn_page_size_1_token
    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if (cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size):
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = 100 * (attn_page_size -
                                   mamba_page_size) / mamba_page_size
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.",
            mamba_padding_pct,
        )

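# Monkey-patch vLLM's hybrid attention/mamba config check with the
# 310P-aware version defined above.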
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config