[Feat][main] Supported to use full-graph with Qwen3-Next-MTP (#5477)

### What this PR does / why we need it?

Supported to use full-graph with Qwen3-Next-MTP.

In detail, we adatpted `AscendAttentionState.ChunkedPrefill` in main
model, and also adapted `AscendAttentionState.ChunkedPrefill` in mtp
model.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

We changed the test of Qwen3-Next-MTP in
`tests/e2e/multicard/test_qwen3_next.py` to make it a test of
`FULL_DECODE_ONLY`. Then run `pytest -s
tests/e2e/multicard/test_qwen3_next.py::test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4`.

And this test passed.

```text
.

================================================================================================================================= warnings summary =================================================================================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================================== 1 passed, 2 warnings in 271.89s (0:04:31) =====================================================================================================================
```
- vLLM version: v0.13.0
- vLLM main:
5326c89803

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark
2026-01-04 12:03:21 +08:00
committed by GitHub
parent fd4b4fd06f
commit 363ac1b80f
4 changed files with 42 additions and 32 deletions

View File

@@ -34,7 +34,6 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODELS = ["Qwen/Qwen3-Next-80B-A3B-Instruct"]
# TODO: add full decode only (when ready)
@pytest.mark.parametrize("model_name", MODELS)
def test_qwen3_next_mtp_acceptance_tp4(model_name):
golden = [0.85, 0.46, 0.19]
@@ -55,6 +54,7 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
distributed_executor_backend="mp",
disable_log_stats=False,
speculative_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
"method": "qwen3_next_mtp",
"num_speculative_tokens": 3,
},
@@ -88,6 +88,8 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
cleanup_dist_env_and_memory()
# FIXME: When applying `FULL_DECODE_ONLY` in this e2e, ci will fail.
# The failure can not be reproduced locally.
@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("num_speculative_tokens", [1])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])

View File

@@ -293,7 +293,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
"Currently we only support building dummy metadata for DecodeOnly and ChunkedPrefill state"
)
attn_metadata.attn_state = attn_state

View File

@@ -30,6 +30,8 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
update_attn_dcp_pcp_params,
update_attn_params,
update_mla_attn_dcp_pcp_params,
update_mla_attn_params)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
@@ -66,6 +68,24 @@ class MtpProposer(EagleProposer):
# TODO: Find out why ModelRunner does not this explicit typing?
model: Union[nn.Module, ACLGraphWrapper]
# update full-graph params for one spec token
def _update_full_graph_params(self, forward_context, num_tokens):
if self.vllm_config.model_config.use_mla:
if self.pcp_size * self.dcp_size > 1:
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context, num_tokens)
else:
update_mla_attn_params(self.update_stream, forward_context,
num_tokens,
self.vllm_config.speculative_config)
else:
if self.pcp_size * self.dcp_size > 1:
update_attn_dcp_pcp_params(self.update_stream, forward_context,
num_tokens)
else:
update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config)
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
@@ -141,7 +161,7 @@ class MtpProposer(EagleProposer):
num_tokens_across_dp,
with_prefill,
) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)
if self.use_async_scheduling:
if not self.use_cuda_graph:
# there is synchronization between mtp steps when enabling aclgraph,
# disable aclgraph when use async scheduling to avoid the
# synchronization overhead.
@@ -185,8 +205,10 @@ class MtpProposer(EagleProposer):
:num_reqs * self.decode_threshold]
builder = self.runner.attn_groups[0][0].get_metadata_builder()
# `AscendAttentionState.SpecDecoding` is only designed for mla, `AscendAttentionState.ChunkedPrefill` is used in self-attention.
attn_state = AscendAttentionState.SpecDecoding if self.vllm_config.model_config.use_mla else AscendAttentionState.ChunkedPrefill
attn_metadata_mtp = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.SpecDecoding)
common_attn_metadata, attn_state)
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
@@ -222,17 +244,9 @@ class MtpProposer(EagleProposer):
hidden_states=previous_hidden_states)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
not forward_context.capturing:
if self.vllm_config.model_config.use_mla and not self.use_sparse:
if self.pcp_size * self.dcp_size > 1:
update_mla_attn_dcp_pcp_params(
self.update_stream, forward_context,
num_tokens)
else:
update_mla_attn_params(
self.update_stream, forward_context,
num_tokens,
self.vllm_config.speculative_config)
not forward_context.capturing and not self.use_sparse:
self._update_full_graph_params(forward_context, num_tokens)
if self.enable_shared_expert_dp:
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
positions, True)
@@ -654,7 +668,7 @@ class MtpProposer(EagleProposer):
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
aclgraph_runtime_mode, batch_descriptor = \
self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
if self.use_async_scheduling:
if not self.use_cuda_graph:
# there is synchronization between mtp steps when enabling aclgraph,
# disable aclgraph when use async scheduling to avoid the
# synchronization overhead.
@@ -721,17 +735,9 @@ class MtpProposer(EagleProposer):
positions=positions,
hidden_states=hidden_states)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
if self.vllm_config.model_config.use_mla and not self.use_sparse:
if self.pcp_size * self.dcp_size > 1:
update_mla_attn_dcp_pcp_params(
self.update_stream, forward_context,
num_input_tokens)
else:
update_mla_attn_params(
self.update_stream, forward_context,
num_input_tokens,
self.vllm_config.speculative_config)
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse:
self._update_full_graph_params(forward_context,
num_input_tokens)
if self.enable_shared_expert_dp:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(

View File

@@ -894,9 +894,6 @@ class NPUModelRunner(GPUModelRunner):
self.logits_indices = logits_indices
# Used in the below loop.
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
self.spec_decode_common_attn_metadata = None
if use_spec_decode and self.need_accepted_tokens:
self.num_accepted_tokens.np[:num_reqs] = (
@@ -991,7 +988,8 @@ class NPUModelRunner(GPUModelRunner):
# TODO: change this to the right block table for linear attn
block_table_tensor=blk_table_tensor[:num_reqs],
slot_mapping=slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
positions=self.positions.gpu,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
@@ -1822,7 +1820,11 @@ class NPUModelRunner(GPUModelRunner):
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and \
self.speculative_config.method == "mtp":
attn_state = AscendAttentionState.SpecDecoding
# `AscendAttentionState.SpecDecoding` is only designed for mla
if self.vllm_config.model_config.use_mla:
attn_state = AscendAttentionState.SpecDecoding
else:
attn_state = AscendAttentionState.ChunkedPrefill
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],