193 lines
7.6 KiB
Python
193 lines
7.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from math import lcm
|
|
from typing import TYPE_CHECKING
|
|
|
|
import vllm.envs as envs
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.models import ModelRegistry
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.math_utils import cdiv, round_up
|
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
|
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
from vllm.model_executor.models.config import (HybridAttentionMambaModelConfig,
|
|
MambaModelConfig)
|
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import VllmConfig
|
|
|
|
# Module logger; used by the HybridAttentionMambaModelConfig hijack below to
# report attention block-size overrides and mamba page padding.
logger = init_logger(__name__)
|
|
|
|
@classmethod
def vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config(
    cls,
    vllm_config: "VllmConfig",
) -> None:
    """
    Ensure that page size of attention layers is greater than or
    equal to the mamba layers. If not, automatically set the attention
    block size to ensure that it is. If the attention page size is
    strictly greater than the mamba page size, we pad the mamba page size
    to make them equal.

    MLU extension: when ``vllm_config.mlu_config.enable_mamba_split_page_size``
    is set, the user-chosen attention block size is kept (it is only filled
    in when unset). Instead, ``mlu_config.mamba_to_attn_block_ratio`` is set
    to the number of ``block_size``-token attention pages needed to cover
    ``attn_block_size`` tokens. ``cache_config.mamba_page_size_padded`` is set
    to exactly one attention page. The function then returns early.
    (NOTE(review): presumably downstream MLU code stores one mamba state across
    ``mamba_to_attn_block_ratio`` pages; confirm.)

    Args:
        vllm_config: vLLM Config. ``cache_config.block_size``,
            ``cache_config.mamba_block_size`` (prefix caching only),
            ``cache_config.mamba_page_size_padded`` and, in split page-size
            mode, ``mlu_config.mamba_to_attn_block_ratio`` may be updated
            in place.
    """
    # Save the user input before it gets modified by MambaModelConfig
    mamba_block_size = vllm_config.cache_config.mamba_block_size
    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    # Attention backend constraints:
    # - FlashAttention (FA) requires block size to be multiple of 16
    # - MLA (Multi-head Latent Attention) requires larger alignment:
    #   * CUTLASS_MLA backend: kernel_block_size 128 alignment
    #   * Other MLA backends: kernel_block_size 64 alignment
    if model_config.use_mla:
        use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
        kernel_block_alignment_size = 128 if use_cutlass_mla else 64
        attn_page_size_1_token = MLAAttentionSpec(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype,
        ).page_size_bytes
    else:
        kernel_block_alignment_size = 16
        if (
            current_platform.is_device_capability(100)
            and model_config.get_head_size() == 256
            and (
                envs.VLLM_ATTENTION_BACKEND is None
                or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
            )
        ):
            # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
            # head size 256 and block size 16 is not supported on blackwell.
            kernel_block_alignment_size = 32
        attn_page_size_1_token = FullAttentionSpec(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype,
        ).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=model_config.max_model_len,
    ).page_size_bytes

    # Model may be marked as is_hybrid
    # but mamba is skipped via config,
    # return directly
    if mamba_page_size == 0:
        return

    if cache_config.enable_prefix_caching:
        # With prefix caching, select attention block size to
        # optimize for mamba kernel performance

        # Mamba2 SSD kernel uses a chunk_size, e.g. 256
        # Align the block to the kernel: use lowest multiple of chunk_size
        # of attention tokens that would fit mamba_page_size:
        # e.g. for mamba page size = 788kB
        #          attn_1_token = 2kB -> fits ~394 tokens
        #      then round up to a multiple of 256 -> 512 tokens
        # End result:
        #  attn_block_size = 512
        #  mamba_block_size = 512 (aligned to a multiple of chunk_size)
        # TODO(tdoublep): this constraint can be relaxed fairly
        # easily by changing the way we layout chunks in the
        # mamba2 kernels.

        base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
        attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
        # Block must satisfy both the mamba chunking and the attention
        # backend alignment, hence the least common multiple.
        chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
        attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
        cache_config.mamba_block_size = attn_block_size
    else:
        # Without prefix caching, select minimum valid attention block size
        # to minimize mamba state padding

        # Calculate minimum attention block size that satisfies both:
        # 1. Backend alignment requirements (kernel_block_alignment_size)
        # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
        attn_block_size = kernel_block_alignment_size * cdiv(
            mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
        )

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: support qwen3-next
    '''
    if vllm_config.mlu_config.enable_mamba_split_page_size:
        if cache_config.block_size is None:
            # The user did not set a block size. Fall back to the minimum
            # valid attention block size, matching the non-split override
            # below (ratio becomes 1). Without this fallback, cdiv() and the
            # multiplication below would raise TypeError on None.
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size,
            )
        # Number of attention-sized pages needed to cover attn_block_size
        # tokens' worth of bytes, i.e. enough room for one mamba state.
        vllm_config.mlu_config.mamba_to_attn_block_ratio = cdiv(
            attn_block_size, cache_config.block_size
        )
        # Each mamba page is padded to exactly one attention page.
        cache_config.mamba_page_size_padded = (
            cache_config.block_size * attn_page_size_1_token
        )
        return
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    # override attention block size if either (a) the
    # user has not set it or (b) the user has set it
    # too small.
    if cache_config.block_size is None or cache_config.block_size < attn_block_size:
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens "
            "to ensure that attention page size is >= mamba page size.",
            attn_block_size,
        )

    # compute new attention page size
    attn_page_size = cache_config.block_size * attn_page_size_1_token

    # Invariant guaranteed by the block size selection above.
    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if (
        cache_config.mamba_page_size_padded is None
        or cache_config.mamba_page_size_padded != attn_page_size
    ):
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = (
            100 * (attn_page_size - mamba_page_size) / mamba_page_size
        )
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.",
            mamba_padding_pct,
        )
|
|
|
|
# Register the MLU-aware implementation above in place of
# HybridAttentionMambaModelConfig.verify_and_update_config.
# NOTE(review): the exact replacement semantics live in
# MluHijackObject.apply_hijack; confirm there.
MluHijackObject.apply_hijack(
    HybridAttentionMambaModelConfig,
    HybridAttentionMambaModelConfig.verify_and_update_config,
    vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config,
)
|