[KVCache] Qwen3.5 supports contiguous tensor hybrid-attn kv-cache (#6887)
### What this PR does / why we need it?
Supports a contiguous-tensor hybrid-attention KV cache for full-attention + Mamba hybrid
models such as Qwen3Next and Qwen3.5.
Due to restrictions of the Ascend operators, all KV tensors, conv
tensors, and SSM tensors must be contiguous. This PR therefore generates the
KV cache with the following three-tensor layout:
tensor1: [(kv_padding), conv, ...]
tensor2: [k, ssm, ...]
tensor3: [v, (mamba_padding), ...]
Under this scheme some memory is wasted on padding, but every cache tensor is
guaranteed to be contiguous.
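A minimal sketch of this layout, assuming each KV-cache block is owned either by an attention layer (k/v) or by a Mamba layer (conv/ssm); the tensor names, byte sizes, and uint8 byte view below are illustrative assumptions, not the code in this PR:

```python
import torch

# Illustrative sizes only; real values come from the model and cache config.
num_blocks = 8
k_bytes = 4096     # K bytes per attention block; the SSM state is aligned to this size
v_bytes = 4096     # V bytes per attention block
conv_bytes = 512   # conv state bytes per Mamba block

# tensor1: [(kv_padding), conv, ...] -- conv state for Mamba layers,
# unused padding for blocks owned by attention layers
tensor1 = torch.empty(num_blocks, conv_bytes, dtype=torch.uint8)

# tensor2: [k, ssm, ...] -- K cache for attention layers, SSM state for Mamba layers
tensor2 = torch.empty(num_blocks, k_bytes, dtype=torch.uint8)

# tensor3: [v, (mamba_padding), ...] -- V cache for attention layers,
# unused padding for blocks owned by Mamba layers
tensor3 = torch.empty(num_blocks, v_bytes, dtype=torch.uint8)

# Each cache (conv, K, SSM, V) is a whole backing tensor rather than a strided
# view into one shared buffer, so every tensor handed to the kernels is contiguous.
assert all(t.is_contiguous() for t in (tensor1, tensor2, tensor3))
```

The padding regions (kv_padding for attention blocks in tensor1, mamba_padding for Mamba blocks in tensor3) are the waste mentioned above; it is the price of keeping each cache contiguous.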
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
@@ -1,11 +1,12 @@
# mypy: ignore-errors
import math

import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size


@classmethod
@@ -33,33 +34,32 @@ def verify_and_update_config(cls, vllm_config) -> None:
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    attn_page_size_1_token = FullAttentionSpec(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype,
    ).page_size_bytes
    kernel_block_size = 128
    # get attention block size
    attn_num_kv_heads = model_config.get_num_kv_heads(parallel_config)
    attn_head_size = model_config.get_head_size()
    attn_single_token_k_page_size = attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype)

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=model_config.max_model_len,
    ).page_size_bytes
    # get mamba block size
    mamba_shapes = model_cls.get_mamba_state_shape_from_config(vllm_config)
    mamba_dtypes = model_cls.get_mamba_state_dtype_from_config(vllm_config)
    mamba_sizes = []
    for shape, dtype in zip(mamba_shapes, mamba_dtypes):
        mamba_sizes.append(math.prod(shape) * get_dtype_size(dtype))
    ssm_block_page_size, conv_block_page_size = max(mamba_sizes), min(mamba_sizes)

    block_alignment_bytes = 128

    # some attention backends (e.g. FA) only support setting
    # block size to multiple of 16, so let's suggest a value
    # that would work (note: FA is currently not compatible
    # with mamba layers, use FlashInfer instead).
    attn_block_size = block_alignment_bytes * cdiv(mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
    # NOTE(zxr): because of the limit of Ascend Hardware, we need to keep
    # all cache tensors contiguous, so we align the page size of ssm_block
    # and single attn_block
    attn_block_size = kernel_block_size * cdiv(ssm_block_page_size, kernel_block_size * attn_single_token_k_page_size)
    assert attn_single_token_k_page_size * attn_block_size == ssm_block_page_size, (
        "Cannot align ssm_page_size and attn_page_size."
    )
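    # For example (illustrative numbers, not from this change): with
    # attn_single_token_k_page_size = 1024 bytes and ssm_block_page_size = 262144 bytes,
    # cdiv(262144, 128 * 1024) = 2, so attn_block_size = 128 * 2 = 256 tokens and
    # 256 * 1024 bytes == 262144 bytes == ssm_block_page_size, satisfying the assert.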

    # override attention block size if either (a) the
    # user has not set it or (b) the user has set it
@@ -72,24 +72,25 @@ def verify_and_update_config(cls, vllm_config) -> None:
        )

    # compute new attention page size
    attn_page_size = cache_config.block_size * attn_page_size_1_token
    attn_page_size = cache_config.block_size * 2 * attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype)

    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size:
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
    # pad mamba page size for conv_blocks
    if (
        cache_config.mamba_page_size_padded is None
        or cache_config.mamba_page_size_padded != attn_page_size + conv_block_page_size
    ):
        cache_config.mamba_page_size_padded = attn_page_size + conv_block_page_size
        mamba_padding_pct = 100 * conv_block_page_size / cache_config.mamba_page_size_padded
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.",
            mamba_padding_pct,
        )
    if cache_config.enable_prefix_caching and cache_config.mamba_cache_mode == "align":
        cache_config.mamba_block_size = cache_config.block_size
    else:
        cache_config.mamba_block_size = model_config.max_model_len


vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config