[BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)

### What this PR does / why we need it?
The pad `-1` modification is from
https://github.com/vllm-project/vllm/pull/25743.

It still has bugs with batched chunked prefill, which this PR fixes.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Author: drslark (committed by GitHub), 2025-12-10 22:54:24 +08:00
Parent: 490ddf536f
Commit: 0fb1dc43a1
8 changed files with 646 additions and 28 deletions


@@ -129,3 +129,28 @@
# Future Plan:
# Remove this patch when adapted vllm version contains the above PR.
#
# ** File: worker/patch_qwen3_next_mtp.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.worker.utils.bind_kv_cache`
# Why:
# The 'bind_kv_cache' func raises a NotImplementedError when current_platform is npu.
# How:
# Replace it with a new bind_kv_cache that skips the raise.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/4770
# Future Plan:
# Remove this patch after discussing with the vLLM community and adapting bind_kv_cache to npu.
#
# ** File: worker/patch_module.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
# Why:
# The npu implementation of 'torch.argsort' does not support bool tensors.
# How:
# Replace it with a wrapper that casts bool inputs to torch.int32 before calling torch.argsort.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/4770
# Future Plan:
# Remove this patch once bool is supported by the npu 'torch.argsort'.
#
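As a side note on why the cast described above is safe: `False`/`True` cast to `0`/`1`, so sorting the int32 view yields the same partitioning a bool argsort would. A minimal sketch on plain CPU torch (not the npu path):

```python
import torch

mask = torch.tensor([True, False, True, False])
order_bool = torch.argsort(mask)                  # supported on CPU
order_int = torch.argsort(mask.to(torch.int32))   # the cast used by the patch

# Both orderings put every False entry before every True entry; only the
# order of ties may differ because argsort is not stable by default.
assert torch.equal(mask[order_bool], mask[order_int])
```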


@@ -32,3 +32,4 @@ import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa
import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa
import vllm_ascend.patch.worker.patch_qwen3_vl # noqa
import vllm_ascend.patch.worker.patch_rope # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa


@@ -0,0 +1,34 @@
import torch


# torch_npu.argsort does not support bool yet; support is planned for a future release.
# TODO: Remove this patch once the npu argsort operator supports bool.
def _argsort(tensor, *args, **kwargs):
    if tensor.dtype == torch.bool:
        return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
    else:
        return torch.argsort(tensor, *args, **kwargs)


class _TorchWrapper:
    """Proxy for the torch module that redirects `argsort` to `_argsort`."""

    def __init__(self):
        self._raw_torch = torch

    def __getattr__(self, name):
        if name == "argsort":
            return _argsort
        else:
            return getattr(self._raw_torch, name)


_is_patched = False


# Patch argsort only for the `torch` reference inside gdn_attn.
def patch_torch_npu_argsort():
    global _is_patched
    if not _is_patched:
        import vllm.v1.attention.backends.gdn_attn as gdn_attn
        gdn_attn.torch = _TorchWrapper()
        _is_patched = True
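A minimal usage sketch for the helper above, assuming vLLM is importable and the module lives at the `vllm_ascend.patch.worker.patch_module` path listed in the patch notes:

```python
import torch
import vllm.v1.attention.backends.gdn_attn as gdn_attn

# Assumed import path; adjust to wherever patch_torch_npu_argsort actually lives.
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort

patch_torch_npu_argsort()

# Inside gdn_attn, `torch` now resolves to the wrapper, so argsort on a bool
# tensor goes through the int32 cast instead of the unsupported npu kernel.
mask = torch.tensor([True, False, True])
order = gdn_attn.torch.argsort(mask)
```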


@@ -0,0 +1,52 @@
from collections import defaultdict

import torch
import vllm.v1.worker.utils as utils
from vllm.attention.layer import Attention
from vllm.v1.worker.utils import extract_layer_index


# Without this patch, initializing the kv_cache raises an exception on npu.
# TODO: To remove this patch, we need to check why the original bind_kv_cache
# raises a NotImplementedError.
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, Attention],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: int = 1,
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0
    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name,
                                       num_attn_module)].append(layer_name)
    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        # The original handling for encoder-decoder models (e.g., BART) is
        # removed here; only the first layer name per index is used.
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]


# Apply the patch: replace the original bind_kv_cache in vllm.v1.worker.utils.
utils.bind_kv_cache = bind_kv_cache
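
For illustration, a minimal sketch of how the patched `bind_kv_cache` is expected to be driven; the `_FakeAttention` class and the layer names below are stand-ins for the real ModelRunner wiring, not actual vLLM objects:

```python
import torch

class _FakeAttention:
    # Stand-in for vllm.attention.layer.Attention; only the attribute that
    # bind_kv_cache touches is modelled here.
    kv_cache = None

kv_caches = {
    "model.layers.0.self_attn.attn": torch.zeros(1),
    "model.layers.1.self_attn.attn": torch.zeros(1),
}
forward_context = {name: _FakeAttention() for name in kv_caches}
runner_kv_caches: list[torch.Tensor] = []

bind_kv_cache(kv_caches, forward_context, runner_kv_caches)

# runner_kv_caches is now ordered by layer index, and each layer's
# forward-context entry holds its cache in a one-element list.
assert len(runner_kv_caches) == 2
assert forward_context["model.layers.1.self_attn.attn"].kv_cache[0] is \
    kv_caches["model.layers.1.self_attn.attn"]
```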