[BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)
### What this PR does / why we need it?
The pad `-1` modification is from https://github.com/vllm-project/vllm/pull/25743; it still has bugs for batched chunked prefill, which this PR addresses.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
@@ -129,3 +129,28 @@
#    Future Plan:
#       Remove this patch when the adapted vllm version contains the above PR.
#
# ** File: worker/patch_qwen3_next_mtp.py**
#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#    1. `vllm.v1.worker.utils.bind_kv_cache`
#       Why:
#          The 'bind_kv_cache' func raises an exception when current_platform is npu.
#       How:
#          Replace it with a new bind_kv_cache that skips the raise.
#       Related PR (if no, explain why):
#          https://github.com/vllm-project/vllm/pull/4770
#       Future Plan:
#          Remove this patch after discussing with the vllm community and adapting bind_kv_cache to npu.
#
# ** File: worker/patch_module.py**
#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#    1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
#       Why:
#          The 'torch.argsort' func on npu does not support bool inputs.
#       How:
#          Replace it with a wrapper that casts bool inputs to torch.int32 before calling torch.argsort.
#       Related PR (if no, explain why):
#          https://github.com/vllm-project/vllm/pull/4770
#       Future Plan:
#          Remove this patch when bool is supported by the 'torch.argsort' func on npu.
#
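The argsort workaround described above hinges on the bool-to-int32 cast preserving ordering. A minimal sketch (not part of this commit, plain torch on CPU) of why the cast is order-equivalent, since False maps to 0 and True to 1:

```python
import torch

# Illustration only: argsort over a bool mask after an int32 cast places all
# False (0) entries before all True (1) entries, same as sorting the bools.
mask = torch.tensor([True, False, True, False])
order = torch.argsort(mask.to(torch.int32))

# The False positions {1, 3} come first, the True positions {0, 2} last.
# The exact order within each group is not guaranteed (argsort is not stable
# by default), so only the grouping is asserted here.
assert set(order[:2].tolist()) == {1, 3}
assert set(order[2:].tolist()) == {0, 2}
```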
@@ -32,3 +32,4 @@ import vllm_ascend.patch.worker.patch_qwen2_5_vl  # noqa
import vllm_ascend.patch.worker.patch_qwen2_5_omni  # noqa
import vllm_ascend.patch.worker.patch_qwen3_vl  # noqa
import vllm_ascend.patch.worker.patch_rope  # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
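The new import activates the patch purely through its import side effect; the module added later in this diff rebinds `utils.bind_kv_cache` at import time. A rough sketch of that mechanic (assuming vllm and vllm_ascend are installed and the patch module has not been imported yet):

```python
import vllm.v1.worker.utils as utils

original = utils.bind_kv_cache
# Importing the patch module is enough; it rebinds utils.bind_kv_cache itself.
import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa: F401

assert utils.bind_kv_cache is not original
```

Note that rebinding the module attribute only affects call sites that look the function up through `vllm.v1.worker.utils` after the patch is applied; code that bound the original function object earlier keeps it.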
vllm_ascend/patch/worker/patch_module.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import torch


# torch_npu.argsort does not support bool yet; it will support it in the future.
# TODO: when the argsort operator supports bool, this patch must be removed.
def _argsort(tensor, *args, **kwargs):
    # Cast bool inputs to int32 so the npu argsort kernel can handle them.
    if tensor.dtype == torch.bool:
        return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
    else:
        return torch.argsort(tensor, *args, **kwargs)


class _TorchWrapper:
    """Proxy for the `torch` module that swaps in the bool-safe argsort."""

    def __init__(self):
        self._raw_torch = torch

    def __getattr__(self, name):
        if name == "argsort":
            return _argsort
        else:
            return getattr(self._raw_torch, name)


_is_patched = False


# Patch argsort only for the `torch` reference held by the gdn_attn module.
def patch_torch_npu_argsort():
    global _is_patched
    if not _is_patched:
        import vllm.v1.attention.backends.gdn_attn as gdn_attn
        gdn_attn.torch = _TorchWrapper()
        _is_patched = True
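A hedged usage sketch of `patch_torch_npu_argsort` (assuming vllm and vllm_ascend are importable and the module paths match the diff): the wrapper is installed only on the `torch` name inside `gdn_attn`, so the global `torch` module stays untouched.

```python
import torch
import vllm.v1.attention.backends.gdn_attn as gdn_attn
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort

patch_torch_npu_argsort()

mask = torch.tensor([True, False, True])
# Inside gdn_attn, torch.argsort now routes bool inputs through the int32 cast.
order = gdn_attn.torch.argsort(mask)
print(order)

# Only the module-local reference was swapped; the real torch module is intact.
assert gdn_attn.torch is not torch
assert torch.argsort is not gdn_attn.torch.argsort
```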
vllm_ascend/patch/worker/patch_qwen3_next_mtp.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import torch
import vllm.v1.worker.utils as utils
from vllm.attention.layer import Attention
from vllm.v1.worker.utils import defaultdict, extract_layer_index


# Without this patch, an exception is raised when the kv_cache is initialized.
# TODO: to remove the patch, we need to check why the original bind_kv_cache
# raises a NotImplementedError.
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, Attention],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: int = 1,
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert the kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name,
                                       num_attn_module)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        # Some upstream code for the typical case of encoder-decoder models,
        # e.g. bart, is removed here.
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]


utils.bind_kv_cache = bind_kv_cache
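Because the assignment on the last line runs at import time, any code that later looks up `utils.bind_kv_cache` gets this npu-safe version. A small, hypothetical usage sketch with toy tensors and `SimpleNamespace` objects standing in for real `Attention` layers (the layer names are made up for illustration):

```python
import types
import torch
from vllm_ascend.patch.worker.patch_qwen3_next_mtp import bind_kv_cache

# Hypothetical stand-ins for the Attention layers in the forward context.
forward_context = {
    "model.layers.0.attn": types.SimpleNamespace(kv_cache=None),
    "model.layers.1.attn": types.SimpleNamespace(kv_cache=None),
}
kv_caches = {name: torch.zeros(2, 4) for name in forward_context}
runner_kv_caches: list[torch.Tensor] = []

bind_kv_cache(kv_caches, forward_context, runner_kv_caches)

# One cache per layer index, ordered by layer index ...
assert len(runner_kv_caches) == 2
# ... and each layer's forward context now holds its cache wrapped in a list.
assert forward_context["model.layers.0.attn"].kv_cache[0] is kv_caches["model.layers.0.attn"]
```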