### What this PR does / why we need it?
See https://github.com/vllm-project/vllm-ascend/pull/7402: the pre-commit hook will forbid `init_logger(__name__)` in vllm_ascend patch modules.
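
For context, a minimal sketch of the pattern involved (the hook forbids the first form; this patch module now imports vLLM's shared logger instead):

```python
# Forbidden in vllm_ascend patch modules once the pre-commit hook lands:
# from vllm.logger import init_logger
# logger = init_logger(__name__)

# Used by this patch module instead: vLLM's shared logger.
from vllm.logger import logger

logger.info("example log from a vllm_ascend patch module")
```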
- vLLM version: v0.17.0
- vLLM main: 8a680463fa
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
# mypy: ignore-errors
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from math import lcm

import vllm.model_executor.models.config
from vllm.logger import logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec


@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    """
    Ensure that page size of attention layers is greater than or
    equal to the mamba layers. If not, automatically set the attention
    block size to ensure that it is. If the attention page size is
    strictly greater than the mamba page size, we pad the mamba page size
    to make them equal.

    Args:
        vllm_config: vLLM Config
    """
    # Save the user input before it gets modified by MambaModelConfig
    mamba_block_size = vllm_config.cache_config.mamba_block_size

    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    if model_config.use_mla:
        raise RuntimeError("MLA is not supported on 310P currently.")
    kernel_block_alignment_size = 128
    attn_page_size_1_token = FullAttentionSpec(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype,
    ).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=-1,
    ).page_size_bytes

    # Model may be marked as is_hybrid
    # but mamba is skipped via config,
    # return directly
    if mamba_page_size == 0:
        return

    if cache_config.mamba_cache_mode == "all":
        base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
        attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
        chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
        attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
        cache_config.mamba_block_size = attn_block_size
    else:
        attn_block_size = kernel_block_alignment_size * cdiv(
            mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
        )

    if (cache_config.block_size is None
            or cache_config.block_size < attn_block_size):
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens to ensure that "
            "attention page size is >= mamba page size.",
            attn_block_size,
        )

    if cache_config.mamba_cache_mode == "align":
        cache_config.mamba_block_size = cache_config.block_size

    attn_page_size = cache_config.block_size * attn_page_size_1_token
    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if (cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size):
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.",
            mamba_padding_pct,
        )


vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
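
For illustration, a self-contained numeric sketch of the default-branch block-size math above; the page sizes below are assumed values, not taken from a real model:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.math_utils.cdiv.
    return -(a // -b)

# Assumed, illustrative sizes (bytes):
attn_page_size_1_token = 4096      # KV-cache bytes per token (FullAttentionSpec with block_size=1)
mamba_page_size = 2_700_000        # mamba state bytes (MambaSpec.page_size_bytes)
kernel_block_alignment_size = 128  # alignment used by the patch above

# Default branch: round the attention block size up to a multiple of 128
# so that one attention page is at least as large as the mamba page.
attn_block_size = kernel_block_alignment_size * cdiv(
    mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
attn_page_size = attn_block_size * attn_page_size_1_token
mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size

print(attn_block_size)              # 768
print(attn_page_size)               # 3145728
print(f"{mamba_padding_pct:.2f}%")  # 16.51%
```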