### What this PR does / why we need it?
Fixes a compatibility bug with `torch_npu.npu_fused_infer_attention_score`
described in
https://github.com/vllm-project/vllm-ascend/issues/4020.
The solution was suggested by @momo609.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
The environment is the same as in
https://github.com/vllm-project/vllm-ascend/issues/4020.
We modified the code according to
https://github.com/vllm-project/vllm-ascend/pull/3918
and ran the code below:
```python
# run with Qwen3-Next MTP
from vllm import LLM, SamplingParams

prompts = [
    "Who are you?",
]
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, top_k=40, max_tokens=128)
llm = LLM(model="/home/model/Qwen3-Next-80B-A3B-Instruct",
          tensor_parallel_size=4,
          enforce_eager=True,
          distributed_executor_backend="mp",
          gpu_memory_utilization=0.7,
          speculative_config={
              "method": "qwen3_next_mtp",
              "num_speculative_tokens": 1,
          },
          max_model_len=4096)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
Outputs:
```text
Prompt: 'Who are you?', Generated text: ' I am Qwen, a large-scale language model independently developed by the Tongyi Lab under Alibaba Group. I am designed to answer questions, create text such as stories, official documents, emails, scripts, and more, as well as perform logical reasoning, programming, and other tasks. If you have any questions or need assistance, feel free to let me know anytime!'
```
Now, `torch_npu.npu_fused_infer_attention_score` is compatible with
Qwen3-Next.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
Signed-off-by: drslark <slarksblood@qq.com>
The patch file (it monkey-patches `HybridAttentionMambaModelConfig.verify_and_update_config`):

```python
# mypy: ignore-errors

import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.11.0"):
    from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
else:
    from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE


@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    """
    Ensure that the page size of the attention layers is greater than or
    equal to that of the mamba layers. If not, automatically set the
    attention block size to ensure that it is. If the attention page size
    is strictly greater than the mamba page size, we pad the mamba page
    size to make them equal.

    Args:
        vllm_config: vLLM Config
    """
    logger = init_logger(__name__)

    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    attn_page_size_1_token = FullAttentionSpec(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=model_config.max_model_len,
    ).page_size_bytes

    block_alignment_bytes = 128

    # some attention backends (e.g. FA) only support setting
    # block size to multiple of 16, so let's suggest a value
    # that would work (note: FA is currently not compatible
    # with mamba layers, use FlashInfer instead).
    attn_block_size = block_alignment_bytes * cdiv(
        mamba_page_size, block_alignment_bytes * attn_page_size_1_token)

    # override attention block size if either (a) the
    # user has not set it or (b) the user has set it
    # too small.
    if (cache_config.block_size is None
            or cache_config.block_size < attn_block_size):
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens "
            "to ensure that attention page size is >= mamba page size.",
            attn_block_size)

    # compute new attention page size
    attn_page_size = \
        cache_config.block_size * attn_page_size_1_token

    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if (cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size):
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = 100 * (attn_page_size -
                                   mamba_page_size) / mamba_page_size
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.", mamba_padding_pct)


vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config
```
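For reference, here is a minimal sketch of the block-size arithmetic in `verify_and_update_config` above, using hypothetical page sizes (the real values come from `FullAttentionSpec` and `MambaSpec` for the loaded model). It shows how the attention block size is rounded up to a multiple of 128 tokens so that one attention page covers one mamba page, and how the resulting mamba padding percentage is computed:

```python
# Minimal sketch of the arithmetic above; the page sizes below are
# hypothetical, not measured from Qwen3-Next.
def cdiv(a: int, b: int) -> int:
    """Ceiling division, equivalent to vllm.utils.cdiv."""
    return -(-a // b)

attn_page_size_1_token = 4096  # hypothetical: bytes of attention KV cache per token
mamba_page_size = 3_000_000    # hypothetical: bytes of one mamba state page

# Round the block size up to a multiple of 128 tokens such that
# one attention page is at least as large as one mamba page.
block_alignment_bytes = 128
attn_block_size = block_alignment_bytes * cdiv(
    mamba_page_size, block_alignment_bytes * attn_page_size_1_token)
print(attn_block_size)  # 768

attn_page_size = attn_block_size * attn_page_size_1_token  # 3,145,728
assert attn_block_size % block_alignment_bytes == 0
assert attn_page_size >= mamba_page_size

# The mamba page size is then padded up to the attention page size.
mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
print(f"{mamba_padding_pct:.2f}%")  # 4.86%
```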