Files
xc-llm-ascend/vllm_ascend/patch/platform/patch_mamba_config.py
Icey d9cdc65854 Upgrade to new vllm commit (#3719)
### What this PR does / why we need it?
Upgrade to new vllm commit:
c9461e05a4

- Fix many imports, caused by
https://github.com/vllm-project/vllm/pull/26908
- Fix import ```sha256```, caused by
https://github.com/vllm-project/vllm/pull/27169
- Remove ```SchedulerConfig.send_delta_data```, caused by
https://github.com/vllm-project/vllm/pull/27142
- Fix ```FusedMoE``` because of dual stream execution, caused by
https://github.com/vllm-project/vllm/pull/26440

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.


- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
2025-10-25 15:36:32 +08:00

104 lines
3.8 KiB
Python

# mypy: ignore-errors
import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
else:
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
"""
Ensure that page size of attention layers is greater than or
equal to the mamba layers. If not, automatically set the attention
block size to ensure that it is. If the attention page size is
strictly greater than the mamba page size, we pad the mamba page size
to make them equal.
Args:
vllm_config: vLLM Config
"""
logger = init_logger(__name__)
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
if cache_config.cache_dtype == "auto":
kv_cache_dtype = model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
model_config=model_config,
)
# get mamba page size
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len,
).page_size_bytes
block_alignment_bytes = 64
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = block_alignment_bytes * cdiv(
mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if (cache_config.block_size is None
or cache_config.block_size < attn_block_size):
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "
"to ensure that attention page size is >= mamba page size.",
attn_block_size)
# compute new attention page size
attn_page_size = \
cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
if attn_page_size == mamba_page_size:
# don't need to pad mamba page size
return
# pad mamba page size to exactly match attention
if (cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size):
cache_config.mamba_page_size_padded = (attn_page_size)
mamba_padding_pct = 100 * (attn_page_size -
mamba_page_size) / mamba_page_size
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.", mamba_padding_pct)
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config