[Feat][main] Supported to use full-graph with Qwen3-Next-MTP (#5477)

### What this PR does / why we need it?

Supported to use full-graph with Qwen3-Next-MTP.

In detail, we adatpted `AscendAttentionState.ChunkedPrefill` in main
model, and also adapted `AscendAttentionState.ChunkedPrefill` in mtp
model.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

We changed the test of Qwen3-Next-MTP in
`tests/e2e/multicard/test_qwen3_next.py` to make it a test of
`FULL_DECODE_ONLY`. Then run `pytest -s
tests/e2e/multicard/test_qwen3_next.py::test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4`.

And this test passed.

```text
.

================================================================================================================================= warnings summary =================================================================================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================================== 1 passed, 2 warnings in 271.89s (0:04:31) =====================================================================================================================
```
- vLLM version: v0.13.0
- vLLM main:
5326c89803

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark
2026-01-04 12:03:21 +08:00
committed by GitHub
parent fd4b4fd06f
commit 363ac1b80f
4 changed files with 42 additions and 32 deletions

View File

@@ -894,9 +894,6 @@ class NPUModelRunner(GPUModelRunner):
self.logits_indices = logits_indices
# Used in the below loop.
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
self.spec_decode_common_attn_metadata = None
if use_spec_decode and self.need_accepted_tokens:
self.num_accepted_tokens.np[:num_reqs] = (
@@ -991,7 +988,8 @@ class NPUModelRunner(GPUModelRunner):
# TODO: change this to the right block table for linear attn
block_table_tensor=blk_table_tensor[:num_reqs],
slot_mapping=slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
positions=self.positions.gpu,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
@@ -1822,7 +1820,11 @@ class NPUModelRunner(GPUModelRunner):
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and \
self.speculative_config.method == "mtp":
attn_state = AscendAttentionState.SpecDecoding
# `AscendAttentionState.SpecDecoding` is only designed for mla
if self.vllm_config.model_config.use_mla:
attn_state = AscendAttentionState.SpecDecoding
else:
attn_state = AscendAttentionState.ChunkedPrefill
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],