[BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)

### What this PR does / why we need it?
The pad `-1` modification is from
https://github.com/vllm-project/vllm/pull/25743.

It still has bugs for batched chunked prefill.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
drslark
2025-12-10 22:54:24 +08:00
committed by GitHub
parent 490ddf536f
commit 0fb1dc43a1
8 changed files with 646 additions and 28 deletions

View File

@@ -24,7 +24,6 @@ Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
import os
from unittest.mock import patch
import pytest
from modelscope import snapshot_download # type: ignore
from tests.e2e.conftest import VllmRunner
@@ -64,14 +63,9 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
del vllm_model
@pytest.mark.skip
# TODO: Fix the accuary of batch chunked prefill
def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
example_prompts = ["Hello, my name is"]
max_tokens = 20
with VllmRunner(
@@ -115,7 +109,6 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
# TODO: will conduct accuracy verification after the subsequent version becomes stable
@pytest.mark.skip
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
def test_models_distributed_Qwen3_NEXT_W8A8DYNAMIC_WITH_EP():
example_prompts = [

View File

@@ -132,6 +132,490 @@ def causal_conv1d_fn(
return out_ref_tensor
# TODO copied from vllm and it needs to be optimized
@triton.jit()
def _original_causal_conv1d_update_kernel(
# Pointers to matrices
x_ptr, # (batch, dim, seqlen)
w_ptr, # (dim, width)
bias_ptr,
conv_state_ptr,
conv_state_indices_ptr,
num_accepted_tokens_ptr,
query_start_loc_ptr, # (batch + 1)
block_idx_last_scheduled_token, # (batch,)
initial_state_idx, # (batch,)
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
dim: tl.constexpr,
seqlen: tl.constexpr,
state_len: tl.constexpr,
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr,
stride_x_dim: tl.constexpr,
stride_x_token: tl.constexpr,
stride_w_dim: tl.constexpr,
stride_w_width: tl.constexpr,
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_APC_ENABLED: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# ruff: noqa: E501
idx_seq = tl.program_id(0)
if idx_seq >= batch:
return
# [BLOCK_N,] elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if IS_APC_ENABLED:
# Get the state from the initial_state_idx
conv_state_init = tl.load(initial_state_idx + idx_seq)
current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
else:
conv_state_init = 0
current_last_index = 0
# cache_idx
conv_states_input_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices +
conv_state_init).to(tl.int64)
if USE_PAD_SLOT: # noqa
if conv_states_input_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
if IS_VARLEN:
query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
tl.int64)
# revise state_len and seqlen
state_len = state_len - (seqlen -
(query_end_index - query_start_index))
seqlen = query_end_index - query_start_index
x_offset = query_start_index * stride_x_token
o_offset = query_start_index * stride_o_token
else:
query_start_index = idx_seq * seqlen
query_end_index = query_start_index + seqlen
x_offset = idx_seq * stride_x_seq
o_offset = idx_seq * stride_o_seq
if query_start_index == query_end_index:
return
if IS_SPEC_DECODING:
# The rolling of conv state:
#
# Before forward, the conv_state is:
# [history1, history2, ..., historyM].
#
# After forward, the conv_state becomes:
# [history2, ..., historyM, draft1, draft2, ..., draftN].
#
# After acceptance, it becomes:
#
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = (
tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
else:
conv_state_token_offset = 0
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
if KERNEL_WIDTH >= 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 3:
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 4:
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 5:
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 6:
conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N]
col4 = tl.load(conv_states_ptrs, mask_w, 0.0)
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# With speculative decoding, the conv_state updates works in a sliding
# window manner, at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
mask = ((conv_states_input_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N]
x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens - VAL >= 0)[:, None]
& (idx_tokens - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()
new_conv_state = tl.where(mask, conv_state, loaded_x)
# Get the state from the initial_state_idx
# cache_idx
conv_states_offset = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices +
current_last_index).to(tl.int64)
conv_state_ptrs_target = (
conv_state_ptr +
(conv_states_offset * stride_conv_state_seq) # Offset from seq
+ (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens * stride_conv_state_tok)[:, None]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)
# STEP 3: init accumulator
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
# STEP 4:
# PRE-LOAD WEIGHTS
# first kernel column, configured for weights to handle BLOCK_N features in range
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
mask_w = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 5:
w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor
w_col4 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 6:
w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor
w_col5 = tl.load(w_ptrs, mask_w, other=0.0)
x_base_1d = x_base # starting of chunk [BLOCK_N]
mask_x_1d = idx_feats < dim
# STEP 5: compute each token
for idx_token in tl.range(seqlen):
acc = acc_preload
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 5:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
matrix_x = col3
elif j == 4:
matrix_w = w_col4
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 6:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
matrix_x = col3
elif j == 4:
matrix_w = w_col4
matrix_x = col4
elif j == 5:
matrix_w = w_col5
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
col0 = col1
col1 = col2
col2 = matrix_x
elif KERNEL_WIDTH == 5:
col0 = col1
col1 = col2
col2 = col3
col3 = matrix_x
elif KERNEL_WIDTH == 6:
col0 = col1
col1 = col2
col2 = col3
col3 = col4
col4 = matrix_x
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
mask_1d = (idx_token < seqlen) & (idx_feats < dim
) # token-index # feature-index
o_ptrs = (o_ptr + o_offset + idx_token * stride_o_token +
(idx_feats * stride_o_dim))
tl.store(o_ptrs, acc, mask=mask_1d)
# TODO copied from vllm and it needs to be optimized
def original_causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
activation: bool | str | None = None,
conv_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
query_start_loc: torch.Tensor | None = None,
max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID,
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
validate_data=False,
):
"""
x: Input tensor which can take the following shapes:
- `[batch, dim]` - single token prediction
- `[batch, dim, seqlen]` - single or multiple tokens prediction
- `[num_tokens, dim]` - continuous batching, where num_tokens is
the total tokens of all sequences in that batch
conv_state: (..., dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
block_idx_last_scheduled_token: (batch,), dtype int32
The pointer into conv_state_indices, where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into conv_state_indices, where the cache block containing the initial state is located.
num_accepted_tokens: (batch,), dtype int32
If not None, it indicates the number of accepted tokens for each
sequence in the batch.
This is used in speculative decoding, where the conv_state is updated
in a sliding window manner.
query_start_loc: (batch + 1,) int32
If not None, the inputs is given in a varlen fashion and this indicates
the starting index of each sequence in the batch.
max_query_len: int
If query_start_loc is not None, this indicates the maximum query
length in the batch.
pad_slot_id: int
if conv_state_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
"""
if validate_data:
assert pad_slot_id is not None
assert x.stride(1) == 1
if isinstance(activation, bool):
activation = "silu" if activation is True else None
elif activation is not None:
assert activation in ["silu", "swish"]
original_x_dtype = x.dtype
x = x.to(conv_state.dtype)
unsqueeze = query_start_loc is None and x.dim() == 2
if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1
x = x.unsqueeze(-1)
if query_start_loc is None:
batch, dim, seqlen = x.shape
else:
assert conv_state_indices is not None
batch = conv_state_indices.size(0)
dim = x.size(1)
seqlen = max_query_len
_, width = weight.shape
# conv_state: (..., dim, state_len), where state_len >= width - 1
num_cache_lines, _, state_len = conv_state.size()
if validate_data:
assert dim == weight.size(0)
assert conv_state.stride(-2) == 1, (
f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
)
assert state_len >= width - 1
# when above happens, we don't shift-left to keep any records in conv_state
assert dim == conv_state.size(1)
if conv_state_indices is None:
assert conv_state.size(0) >= batch
else:
assert (batch, ) == conv_state_indices.shape
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
out = x
stride_w_dim, stride_w_width = weight.stride()
if query_start_loc is None:
# X (batch, dim, seqlen)
stride_x_seq, stride_x_dim, stride_x_token = x.stride()
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
else:
# X (dim, cu_seqlen)
stride_x_token, stride_x_dim = x.stride()
stride_x_seq = 0
stride_o_token, stride_o_dim = out.stride()
stride_o_seq = 0
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
)
stride_state_indices = (conv_state_indices.stride(0)
if conv_state_indices is not None else 0)
if num_accepted_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):
return (
batch,
triton.cdiv(dim, META["BLOCK_N"]),
)
_original_causal_conv1d_update_kernel[grid](
# Pointers to matrices
x,
weight,
bias,
conv_state,
conv_state_indices,
num_accepted_tokens,
query_start_loc,
block_idx_last_scheduled_token,
initial_state_idx,
out,
# Matrix dimensions
batch,
dim,
seqlen,
state_len,
num_cache_lines,
# stride
stride_x_seq,
stride_x_dim,
stride_x_token,
stride_w_dim,
stride_w_width,
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_state_indices,
stride_o_seq,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_VARLEN=query_start_loc is not None,
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,
BLOCK_N=256,
)
if unsqueeze:
out = out.squeeze(-1)
return out.to(original_x_dtype)
@triton.jit()
def _causal_conv1d_update_kernel(
# Pointers to matrices
@@ -392,6 +876,8 @@ def causal_conv1d_update_npu(
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
query_start_loc: torch.Tensor | None = None,
max_query_len: int = -1,
intermediate_conv_window: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
@@ -421,6 +907,19 @@ def causal_conv1d_update_npu(
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if query_start_loc is not None:
return original_causal_conv1d_update(
x=x,
conv_state=conv_state,
weight=weight,
bias=bias,
activation=activation,
conv_state_indices=conv_state_indices,
num_accepted_tokens=num_accepted_tokens,
query_start_loc=query_start_loc,
max_query_len=max_query_len,
validate_data=validate_data)
if validate_data:
assert cache_seqlens is None # not implemented yet - ok for vLLM
assert pad_slot_id is not None

View File

@@ -129,3 +129,28 @@
# Future Plan:
# Remove this patch when adapted vllm version contains the above PR.
#
# ** File: worker/patch_qwen3_next_mtp.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.worker.utils.bind_kv_cache`
# Why:
# 'bind_kv_cache' func will raise an exception when current_platform is npu.
# How
# Replace with a new bind_kv_cache.
# Skip the raise.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/4770
# Future Plan:
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
#
# ** File: worker/patch_module.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
# Why:
# 'torch.argsort' func of npu does not support bool.
# How
# Replace with a new torch.argsort that will cast the input to torch.int32.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/4770
# Future Plan:
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
#

View File

@@ -32,3 +32,4 @@ import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa
import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa
import vllm_ascend.patch.worker.patch_qwen3_vl # noqa
import vllm_ascend.patch.worker.patch_rope # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa

View File

@@ -0,0 +1,34 @@
import torch
# torch_npu.argsort does not sipport bool now, it will support it in the future.
# TODO When the operator of argsort is ready, this patch must be removed.
def _argsort(tensor, *args, **kwargs):
if tensor.dtype == torch.bool:
return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
else:
return torch.argsort(tensor, *args, **kwargs)
class _TorchWrapper:
def __init__(self):
self._raw_torch = torch
def __getattr__(self, name):
if name == "argsort":
return _argsort
else:
return getattr(self._raw_torch, name)
_is_patched = False
# patch argsort only for torch in gdn_attn
def patch_torch_npu_argsort():
global _is_patched
if not _is_patched:
import vllm.v1.attention.backends.gdn_attn as gdn_attn
gdn_attn.torch = _TorchWrapper()
_is_patched = True

View File

@@ -0,0 +1,52 @@
import torch
import vllm.v1.worker.utils as utils
from vllm.attention.layer import Attention
from vllm.v1.worker.utils import defaultdict, extract_layer_index
# Without this patch, it will raise an exception when initialize kv_cache.
# TODO To remove the patch, we need check why the original bind_kv_cache raises an NotImplementedError.
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, Attention],
runner_kv_caches: list[torch.Tensor],
num_attn_module: int = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name,
num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
# remove some codes for the typical case of encoder-decoder model, e.g., bart.
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
utils.bind_kv_cache = bind_kv_cache

View File

@@ -45,7 +45,9 @@ _MTP_MODELS = {
"DeepseekV3ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"DeepseekV32ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"Qwen3NextForCausalLM":
("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
}
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
@@ -200,15 +202,16 @@ class MtpProposer(Proposer):
process_weights_after_loading(self.model, draft_model_config,
target_device)
# check if mtp model use main model's embedding and LMhead
main_model = model
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
main_model.lm_head.weight):
layer_module.shared_head.head = main_model.lm_head
if self.vllm_config.model_config.is_deepseek_mla:
# check if mtp model use main model's embedding and LMhead
main_model = model
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
main_model.lm_head.weight):
layer_module.shared_head.head = main_model.lm_head
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
):

View File

@@ -107,8 +107,7 @@ from vllm.v1.worker.ec_connector_model_runner_mixin import \
ECConnectorModelRunnerMixin
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
gather_mm_placeholders,
from vllm.v1.worker.utils import (AttentionGroup, gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
@@ -138,6 +137,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -619,8 +619,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
# Only relevant for multimodal models
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
@@ -1808,17 +1808,26 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
# For chunked prefills, use -1 as mask rather than 0, as guided
# decoding may rollback speculative tokens.
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
self.input_batch.num_computed_tokens_cpu[req_idx]
>= self.input_batch.num_prompt_tokens[req_idx]) else -1)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
logits_indices = spec_decode_metadata.logits_indices
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
self.num_draft_tokens.np[num_reqs:].fill(0)
self.num_draft_tokens.copy_to_gpu()
# For DECODE only cuda graph of some attention backends (e.g., GDN).
self.num_decode_draft_tokens.np[:
num_reqs] = num_decode_draft_tokens
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
self.num_decode_draft_tokens.copy_to_gpu()
# save logits_indices for pcp spec decode usage
self.logits_indices = logits_indices
@@ -1983,11 +1992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
builder = attn_group.get_metadata_builder()
if isinstance(builder, GDNAttentionMetadataBuilder):
if use_spec_decode:
patch_torch_npu_argsort()
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
gpu[:num_reqs],
num_decode_draft_tokens_cpu=self.num_draft_tokens.
gpu[:num_reqs],
num_decode_draft_tokens_cpu=self.
num_decode_draft_tokens.cpu[:num_reqs],
)
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
@@ -3485,6 +3495,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)
from vllm.v1.worker.utils import bind_kv_cache
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)