xc-llm-ascend/vllm_ascend/patch/worker/patch_qwen3_next_mtp.py
drslark 0fb1dc43a1 [BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)
### What this PR does / why we need it?
The `-1` padding modification is taken from
https://github.com/vllm-project/vllm/pull/25743.

It still has bugs with batched chunked prefill.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
2025-12-10 22:54:24 +08:00


from collections import defaultdict

import torch
import vllm.v1.worker.utils as utils
from vllm.attention.layer import Attention
from vllm.v1.worker.utils import extract_layer_index

# Without this patch, an exception is raised when initializing the kv_cache.
# TODO: To remove this patch, we need to check why the original
# bind_kv_cache raises a NotImplementedError.
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, Attention],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: int = 1,
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
        num_attn_module: Forwarded to `extract_layer_index` when parsing
            layer indices from layer names.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name,
                                       num_attn_module)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        # The upstream handling of several layer names sharing one layer
        # index (the typical case being encoder-decoder models, e.g., bart)
        # is removed here; only the first layer name's cache is appended.
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]


utils.bind_kv_cache = bind_kv_cache
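To make the ordering and binding logic concrete, here is a minimal, self-contained sketch that mirrors what the patched function does, without torch or vLLM. `FakeAttention`, `toy_extract_layer_index`, and `toy_bind_kv_cache` are hypothetical names introduced only for illustration: plain strings stand in for KV-cache tensors, and the index parser is a simplification that ignores `num_attn_module`, so it is not vLLM's actual `extract_layer_index`.

```python
from collections import defaultdict


class FakeAttention:
    """Hypothetical stand-in for vLLM's Attention layer (only holds kv_cache)."""

    def __init__(self) -> None:
        self.kv_cache: list = []


def toy_extract_layer_index(layer_name: str) -> int:
    """Hypothetical, simplified layer-index parser.

    Returns N for names such as 'model.layers.N.attn'; assumes the name has
    exactly one purely numeric dotted component.
    """
    ints = [int(part) for part in layer_name.split(".") if part.isdigit()]
    assert len(ints) == 1, f"expected one layer index in {layer_name!r}"
    return ints[0]


def toy_bind_kv_cache(kv_caches: dict, forward_context: dict,
                      runner_kv_caches: list) -> None:
    assert len(runner_kv_caches) == 0
    # Group layer names by their parsed layer index.
    index2name: dict[int, list[str]] = defaultdict(list)
    for name in kv_caches:
        index2name[toy_extract_layer_index(name)].append(name)
    # The runner list is ordered by layer index; like the patch, keep only
    # the first name per index instead of raising when several share one.
    for idx in sorted(index2name):
        runner_kv_caches.append(kv_caches[index2name[idx][0]])
    # Every layer, including ones skipped above, still gets its own cache.
    for name, cache in kv_caches.items():
        forward_context[name].kv_cache = [cache]


# Insertion order is deliberately shuffled to show the sort by layer index.
kv_caches = {"model.layers.2.attn": "cache2",
             "model.layers.0.attn": "cache0",
             "model.layers.1.attn": "cache1"}
ctx = {name: FakeAttention() for name in kv_caches}
runner: list = []
toy_bind_kv_cache(kv_caches, ctx, runner)
print(runner)  # ['cache0', 'cache1', 'cache2']

# Two layers sharing index 0: only the first reaches the runner list.
shared = {"model.layers.0.self_attn": "a", "model.layers.0.cross_attn": "b"}
shared_ctx = {name: FakeAttention() for name in shared}
shared_runner: list = []
toy_bind_kv_cache(shared, shared_ctx, shared_runner)
print(shared_runner)  # ['a']
```

Note that the patch works by rebinding the `bind_kv_cache` attribute on the `vllm.v1.worker.utils` module. In Python, that only affects callers that look the function up through the module at call time; a module that already executed `from vllm.v1.worker.utils import bind_kv_cache` before the patch ran would keep its reference to the original function.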